From e852aa7882d93d9bafe3a6438506feb5a440bf87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Loyck=20Andres?= Date: Mon, 4 Oct 2021 12:12:45 +0200 Subject: [PATCH 1/2] Add scala.annotation.MainAnnotation See `docs/_docs/reference/experimental/main-annotation.md` --- .../dotty/tools/dotc/ast/MainProxies.scala | 350 ++++++++++++++++-- .../dotty/tools/dotc/core/Definitions.scala | 8 + .../src/dotty/tools/dotc/core/StdNames.scala | 4 + .../src/dotty/tools/dotc/typer/Checking.scala | 7 +- .../src/dotty/tools/dotc/typer/Typer.scala | 2 +- .../reference/experimental/main-annotation.md | 91 +++++ docs/sidebar.yml | 1 + .../src/scala/annotation/MainAnnotation.scala | 119 ++++++ project/MiMaFilters.scala | 12 +- .../referenceReplacements/sidebar.yml | 1 + .../neg/main-annotation-mainannotation.scala | 3 + tests/run/main-annotation-example.check | 3 + tests/run/main-annotation-example.scala | 59 +++ .../main-annotation-homemade-annot-1.check | 4 + .../main-annotation-homemade-annot-1.scala | 45 +++ .../main-annotation-homemade-annot-2.check | 11 + .../main-annotation-homemade-annot-2.scala | 50 +++ .../main-annotation-homemade-annot-3.check | 1 + .../main-annotation-homemade-annot-3.scala | 24 ++ .../main-annotation-homemade-annot-4.check | 1 + .../main-annotation-homemade-annot-4.scala | 24 ++ .../main-annotation-homemade-annot-5.check | 2 + .../main-annotation-homemade-annot-5.scala | 26 ++ .../main-annotation-homemade-annot-6.check | 28 ++ .../main-annotation-homemade-annot-6.scala | 63 ++++ tests/run/main-annotation-newMain.scala | 307 +++++++++++++++ 26 files changed, 1217 insertions(+), 29 deletions(-) create mode 100644 docs/_docs/reference/experimental/main-annotation.md create mode 100644 library/src/scala/annotation/MainAnnotation.scala create mode 100644 tests/neg/main-annotation-mainannotation.scala create mode 100644 tests/run/main-annotation-example.check create mode 100644 tests/run/main-annotation-example.scala create mode 100644 tests/run/main-annotation-homemade-annot-1.check create mode 100644 tests/run/main-annotation-homemade-annot-1.scala create mode 100644 tests/run/main-annotation-homemade-annot-2.check create mode 100644 tests/run/main-annotation-homemade-annot-2.scala create mode 100644 tests/run/main-annotation-homemade-annot-3.check create mode 100644 tests/run/main-annotation-homemade-annot-3.scala create mode 100644 tests/run/main-annotation-homemade-annot-4.check create mode 100644 tests/run/main-annotation-homemade-annot-4.scala create mode 100644 tests/run/main-annotation-homemade-annot-5.check create mode 100644 tests/run/main-annotation-homemade-annot-5.scala create mode 100644 tests/run/main-annotation-homemade-annot-6.check create mode 100644 tests/run/main-annotation-homemade-annot-6.scala create mode 100644 tests/run/main-annotation-newMain.scala diff --git a/compiler/src/dotty/tools/dotc/ast/MainProxies.scala b/compiler/src/dotty/tools/dotc/ast/MainProxies.scala index 183854f3aede..01ae50850b57 100644 --- a/compiler/src/dotty/tools/dotc/ast/MainProxies.scala +++ b/compiler/src/dotty/tools/dotc/ast/MainProxies.scala @@ -2,30 +2,40 @@ package dotty.tools.dotc package ast import core._ -import Symbols._, Types._, Contexts._, Flags._, Constants._ -import StdNames.nme - -/** Generate proxy classes for @main functions. - * A function like - * - * @main def f(x: S, ys: T*) = ... - * - * would be translated to something like - * - * import CommandLineParser._ - * class f { - * @static def main(args: Array[String]): Unit = - * try - * f( - * parseArgument[S](args, 0), - * parseRemainingArguments[T](args, 1): _* - * ) - * catch case err: ParseError => showError(err) - * } - */ +import Symbols._, Types._, Contexts._, Decorators._, util.Spans._, Flags._, Constants._ +import StdNames.{nme, tpnme} +import ast.Trees._ +import Names.Name +import Comments.Comment +import NameKinds.DefaultGetterName +import Annotations.Annotation + object MainProxies { - def mainProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = { + /** Generate proxy classes for @main functions and @myMain functions where myMain <:< MainAnnotation */ + def proxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = { + mainAnnotationProxies(stats) ++ mainProxies(stats) + } + + /** Generate proxy classes for @main functions. + * A function like + * + * @main def f(x: S, ys: T*) = ... + * + * would be translated to something like + * + * import CommandLineParser._ + * class f { + * @static def main(args: Array[String]): Unit = + * try + * f( + * parseArgument[S](args, 0), + * parseRemainingArguments[T](args, 1): _* + * ) + * catch case err: ParseError => showError(err) + * } + */ + private def mainProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = { import tpd._ def mainMethods(stats: List[Tree]): List[Symbol] = stats.flatMap { case stat: DefDef if stat.symbol.hasAnnotation(defn.MainAnnot) => @@ -39,7 +49,7 @@ object MainProxies { } import untpd._ - def mainProxy(mainFun: Symbol)(using Context): List[TypeDef] = { + private def mainProxy(mainFun: Symbol)(using Context): List[TypeDef] = { val mainAnnotSpan = mainFun.getAnnotation(defn.MainAnnot).get.tree.span def pos = mainFun.sourcePos val argsRef = Ident(nme.args) @@ -116,4 +126,298 @@ object MainProxies { } result } + + private type DefaultValueSymbols = Map[Int, Symbol] + private type ParameterAnnotationss = Seq[Seq[Annotation]] + + /** + * Generate proxy classes for main functions. + * A function like + * + * /** + * * Lorem ipsum dolor sit amet + * * consectetur adipiscing elit. + * * + * * @param x my param x + * * @param ys all my params y + * */ + * @myMain(80) def f( + * @myMain.Alias("myX") x: S, + * ys: T* + * ) = ... + * + * would be translated to something like + * + * final class f { + * static def main(args: Array[String]): Unit = { + * val cmd = new myMain(80).command( + * info = new CommandInfo( + * name = "f", + * documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.", + * parameters = Seq( + * new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX"))) + * new scala.annotation.MainAnnotation.ParameterInfo("ys", "T", false, false, "all my params y", Seq()) + * ) + * ) + * args = args + * ) + * + * val args0: () => S = cmd.argGetter[S](0, None) + * val args1: () => Seq[T] = cmd.varargGetter[T] + * + * cmd.run(() => f(args0(), args1()*)) + * } + * } + */ + private def mainAnnotationProxies(stats: List[tpd.Tree])(using Context): List[untpd.Tree] = { + import tpd._ + + /** + * Computes the symbols of the default values of the function. Since they cannot be inferred anymore at this + * point of the compilation, they must be explicitly passed by [[mainProxy]]. + */ + def defaultValueSymbols(scope: Tree, funSymbol: Symbol): DefaultValueSymbols = + scope match { + case TypeDef(_, template: Template) => + template.body.flatMap((_: Tree) match { + case dd: DefDef if dd.name.is(DefaultGetterName) && dd.name.firstPart == funSymbol.name => + val DefaultGetterName.NumberedInfo(index) = dd.name.info + List(index -> dd.symbol) + case _ => Nil + }).toMap + case _ => Map.empty + } + + /** Computes the list of main methods present in the code. */ + def mainMethods(scope: Tree, stats: List[Tree]): List[(Symbol, ParameterAnnotationss, DefaultValueSymbols, Option[Comment])] = stats.flatMap { + case stat: DefDef => + val sym = stat.symbol + sym.annotations.filter(_.matches(defn.MainAnnotationClass)) match { + case Nil => + Nil + case _ :: Nil => + val paramAnnotations = stat.paramss.flatMap(_.map( + valdef => valdef.symbol.annotations.filter(_.matches(defn.MainAnnotationParameterAnnotation)) + )) + (sym, paramAnnotations.toVector, defaultValueSymbols(scope, sym), stat.rawComment) :: Nil + case mainAnnot :: others => + report.error(s"method cannot have multiple main annotations", mainAnnot.tree) + Nil + } + case stat @ TypeDef(_, impl: Template) if stat.symbol.is(Module) => + mainMethods(stat, impl.body) + case _ => + Nil + } + + // Assuming that the top-level object was already generated, all main methods will have a scope + mainMethods(EmptyTree, stats).flatMap(mainAnnotationProxy) + } + + private def mainAnnotationProxy(mainFun: Symbol, paramAnnotations: ParameterAnnotationss, defaultValueSymbols: DefaultValueSymbols, docComment: Option[Comment])(using Context): Option[TypeDef] = { + val mainAnnot = mainFun.getAnnotation(defn.MainAnnotationClass).get + def pos = mainFun.sourcePos + + val documentation = new Documentation(docComment) + + /** () => value */ + def unitToValue(value: Tree): Tree = + val defDef = DefDef(nme.ANON_FUN, List(Nil), TypeTree(), value) + Block(defDef, Closure(Nil, Ident(nme.ANON_FUN), EmptyTree)) + + /** Generate a list of trees containing the ParamInfo instantiations. + * + * A ParamInfo has the following shape + * ``` + * new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX"))) + * ``` + */ + def parameterInfos(mt: MethodType): List[Tree] = + extension (tree: Tree) def withProperty(sym: Symbol, args: List[Tree]) = + Apply(Select(tree, sym.name), args) + + for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield + val param = paramName.toString + val paramType0 = if formal.isRepeatedParam then formal.argTypes.head.dealias else formal.dealias + val paramType = paramType0.dealias + + val paramTypeStr = formal.dealias.typeSymbol.owner.showFullName + "." + paramType.show + val hasDefault = defaultValueSymbols.contains(idx) + val isRepeated = formal.isRepeatedParam + val paramDoc = documentation.argDocs.getOrElse(param, "") + val paramAnnots = + val annotationTrees = paramAnnotations(idx).map(instantiateAnnotation).toList + Apply(ref(defn.SeqModule.termRef), annotationTrees) + + val constructorArgs = List(param, paramTypeStr, hasDefault, isRepeated, paramDoc) + .map(value => Literal(Constant(value))) + + New(TypeTree(defn.MainAnnotationParameterInfo.typeRef), List(constructorArgs :+ paramAnnots)) + + end parameterInfos + + /** + * Creates a list of references and definitions of arguments. + * The goal is to create the + * `val args0: () => S = cmd.argGetter[S](0, None)` + * part of the code. + */ + def argValDefs(mt: MethodType): List[ValDef] = + for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield + val argName = nme.args ++ idx.toString + val isRepeated = formal.isRepeatedParam + val formalType = if isRepeated then formal.argTypes.head else formal + val getterName = if isRepeated then nme.varargGetter else nme.argGetter + val defaultValueGetterOpt = defaultValueSymbols.get(idx) match + case None => ref(defn.NoneModule.termRef) + case Some(dvSym) => + val value = unitToValue(ref(dvSym.termRef)) + Apply(ref(defn.SomeClass.companionModule.termRef), value) + val argGetter0 = TypeApply(Select(Ident(nme.cmd), getterName), TypeTree(formalType) :: Nil) + val argGetter = + if isRepeated then argGetter0 + else Apply(argGetter0, List(Literal(Constant(idx)), defaultValueGetterOpt)) + + ValDef(argName, TypeTree(), argGetter) + end argValDefs + + + /** Create a list of argument references that will be passed as argument to the main method. + * `args0`, ...`argn*` + */ + def argRefs(mt: MethodType): List[Tree] = + for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield + val argRef = Apply(Ident(nme.args ++ idx.toString), Nil) + if formal.isRepeatedParam then repeated(argRef) else argRef + end argRefs + + + /** Turns an annotation (e.g. `@main(40)`) into an instance of the class (e.g. `new scala.main(40)`). */ + def instantiateAnnotation(annot: Annotation): Tree = + val argss = { + def recurse(t: tpd.Tree, acc: List[List[Tree]]): List[List[Tree]] = t match { + case Apply(t, args: List[tpd.Tree]) => recurse(t, extractArgs(args) :: acc) + case _ => acc + } + + def extractArgs(args: List[tpd.Tree]): List[Tree] = + args.flatMap { + case Typed(SeqLiteral(varargs, _), _) => varargs.map(arg => TypedSplice(arg)) + case arg: Select if arg.name.is(DefaultGetterName) => Nil // Ignore default values, they will be added later by the compiler + case arg => List(TypedSplice(arg)) + } + + recurse(annot.tree, Nil) + } + + New(TypeTree(annot.symbol.typeRef), argss) + end instantiateAnnotation + + def generateMainClass(mainCall: Tree, args: List[Tree], parameterInfos: List[Tree]): TypeDef = + val cmdInfo = + val nameTree = Literal(Constant(mainFun.showName)) + val docTree = Literal(Constant(documentation.mainDoc)) + val paramInfos = Apply(ref(defn.SeqModule.termRef), parameterInfos) + New(TypeTree(defn.MainAnnotationCommandInfo.typeRef), List(List(nameTree, docTree, paramInfos))) + + val cmd = ValDef( + nme.cmd, + TypeTree(), + Apply( + Select(instantiateAnnotation(mainAnnot), nme.command), + List(cmdInfo, Ident(nme.args)) + ) + ) + val run = Apply(Select(Ident(nme.cmd), nme.run), mainCall) + val body = Block(cmdInfo :: cmd :: args, run) + val mainArg = ValDef(nme.args, TypeTree(defn.ArrayType.appliedTo(defn.StringType)), EmptyTree) + .withFlags(Param) + /** Replace typed `Ident`s that have been typed with a TypeSplice with the reference to the symbol. + * The annotations will be retype-checked in another scope that may not have the same imports. + */ + def insertTypeSplices = new TreeMap { + override def transform(tree: Tree)(using Context): Tree = tree match + case tree: tpd.Ident @unchecked => TypedSplice(tree) + case tree => super.transform(tree) + } + val annots = mainFun.annotations + .filterNot(_.matches(defn.MainAnnotationClass)) + .map(annot => insertTypeSplices.transform(annot.tree)) + val mainMeth = DefDef(nme.main, (mainArg :: Nil) :: Nil, TypeTree(defn.UnitType), body) + .withFlags(JavaStatic) + .withAnnotations(annots) + val mainTempl = Template(emptyConstructor, Nil, Nil, EmptyValDef, mainMeth :: Nil) + val mainCls = TypeDef(mainFun.name.toTypeName, mainTempl) + .withFlags(Final | Invisible) + mainCls.withSpan(mainAnnot.tree.span.toSynthetic) + end generateMainClass + + if (!mainFun.owner.isStaticOwner) + report.error(s"main method is not statically accessible", pos) + None + else mainFun.info match { + case _: ExprType => + Some(generateMainClass(unitToValue(ref(mainFun.termRef)), Nil, Nil)) + case mt: MethodType => + if (mt.isImplicitMethod) + report.error(s"main method cannot have implicit parameters", pos) + None + else mt.resType match + case restpe: MethodType => + report.error(s"main method cannot be curried", pos) + None + case _ => + Some(generateMainClass(unitToValue(Apply(ref(mainFun.termRef), argRefs(mt))), argValDefs(mt), parameterInfos(mt))) + case _: PolyType => + report.error(s"main method cannot have type parameters", pos) + None + case _ => + report.error(s"main can only annotate a method", pos) + None + } + } + + /** A class responsible for extracting the docstrings of a method. */ + private class Documentation(docComment: Option[Comment]): + import util.CommentParsing._ + + /** The main part of the documentation. */ + lazy val mainDoc: String = _mainDoc + /** The parameters identified by @param. Maps from parameter name to its documentation. */ + lazy val argDocs: Map[String, String] = _argDocs + + private var _mainDoc: String = "" + private var _argDocs: Map[String, String] = Map() + + docComment match { + case Some(comment) => if comment.isDocComment then parseDocComment(comment.raw) else _mainDoc = comment.raw + case None => + } + + private def cleanComment(raw: String): String = + var lines: Seq[String] = raw.trim.nn.split('\n').nn.toSeq + lines = lines.map(l => l.substring(skipLineLead(l, -1), l.length).nn.trim.nn) + var s = lines.foldLeft("") { + case ("", s2) => s2 + case (s1, "") if s1.last == '\n' => s1 // Multiple newlines are kept as single newlines + case (s1, "") => s1 + '\n' + case (s1, s2) if s1.last == '\n' => s1 + s2 + case (s1, s2) => s1 + ' ' + s2 + } + s.replaceAll(raw"\[\[", "").nn.replaceAll(raw"\]\]", "").nn.trim.nn + + private def parseDocComment(raw: String): Unit = + // Positions of the sections (@) in the docstring + val tidx: List[(Int, Int)] = tagIndex(raw) + + // Parse main comment + var mainComment: String = raw.substring(skipLineLead(raw, 0), startTag(raw, tidx)).nn + _mainDoc = cleanComment(mainComment) + + // Parse arguments comments + val argsCommentsSpans: Map[String, (Int, Int)] = paramDocs(raw, "@param", tidx) + val argsCommentsTextSpans = argsCommentsSpans.view.mapValues(extractSectionText(raw, _)) + val argsCommentsTexts = argsCommentsTextSpans.mapValues({ case (beg, end) => raw.substring(beg, end).nn }) + _argDocs = argsCommentsTexts.mapValues(cleanComment(_)).toMap + end Documentation } diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 5eb1ccb0f957..9e70102fa1dd 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -528,6 +528,8 @@ class Definitions { @tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType)) @tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length) @tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq) + @tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq") + @tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps") @tu lazy val StringOps_format: Symbol = StringOps.requiredMethod(nme.format) @@ -853,6 +855,12 @@ class Definitions { @tu lazy val XMLTopScopeModule: Symbol = requiredModule("scala.xml.TopScope") + @tu lazy val MainAnnotationClass: ClassSymbol = requiredClass("scala.annotation.MainAnnotation") + @tu lazy val MainAnnotationCommandInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.CommandInfo") + @tu lazy val MainAnnotationParameterInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterInfo") + @tu lazy val MainAnnotationParameterAnnotation: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterAnnotation") + @tu lazy val MainAnnotationCommand: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Command") + @tu lazy val CommandLineParserModule: Symbol = requiredModule("scala.util.CommandLineParser") @tu lazy val CLP_ParseError: ClassSymbol = CommandLineParserModule.requiredClass("ParseError").typeRef.symbol.asClass @tu lazy val CLP_parseArgument: Symbol = CommandLineParserModule.requiredMethod("parseArgument") diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 1bf91bf69abe..bb5efcc01a5e 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -397,6 +397,7 @@ object StdNames { val applyOrElse: N = "applyOrElse" val args : N = "args" val argv : N = "argv" + val argGetter : N = "argGetter" val arrayClass: N = "arrayClass" val arrayElementClass: N = "arrayElementClass" val arrayType: N = "arrayType" @@ -427,6 +428,8 @@ object StdNames { val classOf: N = "classOf" val classType: N = "classType" val clone_ : N = "clone" + val cmd: N = "cmd" + val command: N = "command" val common: N = "common" val compiletime : N = "compiletime" val conforms_ : N = "$conforms" @@ -613,6 +616,7 @@ object StdNames { val fromOrdinal: N = "fromOrdinal" val values: N = "values" val view_ : N = "view" + val varargGetter : N = "varargGetter" val wait_ : N = "wait" val wildcardType: N = "wildcardType" val withFilter: N = "withFilter" diff --git a/compiler/src/dotty/tools/dotc/typer/Checking.scala b/compiler/src/dotty/tools/dotc/typer/Checking.scala index b7c65a30e7b4..1cce3fdea280 100644 --- a/compiler/src/dotty/tools/dotc/typer/Checking.scala +++ b/compiler/src/dotty/tools/dotc/typer/Checking.scala @@ -1351,12 +1351,13 @@ trait Checking { def checkAnnotApplicable(annot: Tree, sym: Symbol)(using Context): Boolean = !ctx.reporter.reportsErrorsFor { val annotCls = Annotations.annotClass(annot) + val concreteAnnot = Annotations.ConcreteAnnotation(annot) val pos = annot.srcPos - if (annotCls == defn.MainAnnot) { + if (annotCls == defn.MainAnnot || concreteAnnot.matches(defn.MainAnnotationClass)) { if (!sym.isRealMethod) - report.error(em"@main annotation cannot be applied to $sym", pos) + report.error(em"main annotation cannot be applied to $sym", pos) if (!sym.owner.is(Module) || !sym.owner.isStatic) - report.error(em"$sym cannot be a @main method since it cannot be accessed statically", pos) + report.error(em"$sym cannot be a main method since it cannot be accessed statically", pos) } // TODO: Add more checks here } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index ac8d6152812e..d915a35b88b4 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2602,7 +2602,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer pkg.moduleClass.info.decls.lookup(topLevelClassName).ensureCompleted() var stats1 = typedStats(tree.stats, pkg.moduleClass)._1 if (!ctx.isAfterTyper) - stats1 = stats1 ++ typedBlockStats(MainProxies.mainProxies(stats1))._1 + stats1 = stats1 ++ typedBlockStats(MainProxies.proxies(stats1))._1 cpy.PackageDef(tree)(pid1, stats1).withType(pkg.termRef) } case _ => diff --git a/docs/_docs/reference/experimental/main-annotation.md b/docs/_docs/reference/experimental/main-annotation.md new file mode 100644 index 000000000000..c87d143ed15f --- /dev/null +++ b/docs/_docs/reference/experimental/main-annotation.md @@ -0,0 +1,91 @@ +--- +layout: doc-page +title: "MainAnnotation" +--- + +`MainAnnotation` provides a generic way to define main annotations such as `@main`. + +When a users annotates a method with an annotation that extends `MainAnnotation` a class with a `main` method will be generated. The main method will contain the code needed to parse the command line arguments and run the application. + +```scala +/** Sum all the numbers + * + * @param first Fist number to sum + * @param rest The rest of the numbers to sum + */ +@myMain def sum(first: Int, rest: Int*): Int = first + rest.sum +``` + +```scala +object foo { + def main(args: Array[String]): Unit = { + + val cmd = new myMain().command( + info = new CommandInfo( + name = "sum", + documentation = "Sum all the numbers", + parameters = Seq( + new ParameterInfo("first", "scala.Int", hasDefault=false, isVarargs=false, "Fist number to sum", Seq()), + new ParameterInfo("rest", "scala.Int" , hasDefault=false, isVarargs=true, "The rest of the numbers to sum", Seq()) + ) + ), + args = args + ) + val args0 = cmd.argGetter[Int](0, None) // using a parser of Int + val args1 = cmd.varargGetter[Int] // using a parser of Int + cmd.run(() => sum(args0(), args1()*)) + } +} +``` + +The implementation of the `main` method first instantiates the annotation and then creates a `Command`. +When creating the `Command`, the arguments can be checked and preprocessed. +Then it defines a series of argument getters calling `argGetter` for each parameter and `varargGetter` for the last one if it is a varargs. `argGetter` gets an optional lambda that computes the default argument. +Finally, the `run` method is called to run the application. It receives a by-name argument that contains the call the annotated method with the instantiations arguments (using the lambdas from `argGetter`/`varargGetter`). + + +Example of implementation of `myMain` that takes all arguments positionally. It used `util.CommandLineParser.FromString` and expects no default arguments. For simplicity, any errors in preprocessing or parsing results in crash. + +```scala +// Parser used to parse command line arguments +import scala.util.CommandLineParser.FromString[T] + +// Result type of the annotated method is Int +class myMain extends MainAnnotation: + import MainAnnotation.{ ParameterInfo, Command } + + /** A new command with arguments from `args` */ + def command(info: CommandInfo, args: Array[String]): Command[FromString, Int] = + if args.contains("--help") then + println(info.documentation) + // TODO: Print documentation of the parameters + System.exit(0) + assert(info.parameters.forall(!_.hasDefault), "Default arguments are not supported") + val (plainArgs, varargs) = + if info.parameters.last.isVarargs then + val numPlainArgs = info.parameters.length - 1 + assert(numPlainArgs <= args.length, "Not enough arguments") + (args.take(numPlainArgs), args.drop(numPlainArgs)) + else + assert(info.parameters.length <= args.length, "Not enough arguments") + assert(info.parameters.length >= args.length, "Too many arguments") + (args, Array.empty[String]) + new MyCommand(plainArgs, varargs) + + @experimental + class MyCommand(plainArgs: Seq[String], varargs: Seq[String]) extends Command[FromString, Int]: + + def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using parser: FromString[T]): () => T = + () => parser.fromString(plainArgs(idx)) + + def varargGetter[T](using parser: FromString[T]): () => Seq[T] = + () => varargs.map(arg => parser.fromString(arg)) + + def run(program: () => Int): Unit = + println("executing program") + val result = program() + println("result: " + result) + println("executed program") + end MyCommand +end myMain +``` diff --git a/docs/sidebar.yml b/docs/sidebar.yml index 0f6ed6bf935d..7c68120bfbc2 100644 --- a/docs/sidebar.yml +++ b/docs/sidebar.yml @@ -147,6 +147,7 @@ subsection: - page: reference/experimental/named-typeargs-spec.md - page: reference/experimental/numeric-literals.md - page: reference/experimental/explicit-nulls.md + - page: reference/experimental/main-annotation.md - page: reference/experimental/cc.md - page: reference/experimental/tupled-function.md - page: reference/syntax.md diff --git a/library/src/scala/annotation/MainAnnotation.scala b/library/src/scala/annotation/MainAnnotation.scala new file mode 100644 index 000000000000..6e30ee6f69a3 --- /dev/null +++ b/library/src/scala/annotation/MainAnnotation.scala @@ -0,0 +1,119 @@ +package scala.annotation + +/** MainAnnotation provides the functionality for a compiler-generated main class. + * It links a compiler-generated main method (call it compiler-main) to a user + * written main method (user-main). + * The protocol of calls from compiler-main is as follows: + * + * - create a `command` with the command line arguments, + * - for each parameter of user-main, a call to `command.argGetter`, + * or `command.varargGetter` if is a final varargs parameter, + * - a call to `command.run` with the closure of user-main applied to all arguments. + * + * Example: + * ```scala + * /** Sum all the numbers + * * + * * @param first Fist number to sum + * * @param rest The rest of the numbers to sum + * */ + * @myMain def sum(first: Int, rest: Int*): Int = first + rest.sum + * ``` + * generates + * ```scala + * object foo { + * def main(args: Array[String]): Unit = { + * val cmd = new myMain().command( + * info = new CommandInfo( + * name = "foo.main", + * documentation = "Sum all the numbers", + * parameters = Seq( + * new ParameterInfo("first", "scala.Int", hasDefault=false, isVarargs=false, "Fist number to sum"), + * new ParameterInfo("rest", "scala.Int" , hasDefault=false, isVarargs=true, "The rest of the numbers to sum") + * ) + * ) + * args = args + * ) + * val args0 = cmd.argGetter[Int](0, None) // using cmd.Parser[Int] + * val args1 = cmd.varargGetter[Int] // using cmd.Parser[Int] + * cmd.run(() => sum(args0(), args1()*)) + * } + * } + * ``` + * + */ +@experimental +trait MainAnnotation extends StaticAnnotation: + import MainAnnotation.{Command, CommandInfo} + + /** A new command with arguments from `args` + * + * @param info The information about the command (name, documentation and info about parameters) + * @param args The command line arguments + */ + def command(info: CommandInfo, args: Array[String]): Command[?, ?] + +end MainAnnotation + +@experimental +object MainAnnotation: + + /** A class representing a command to run + * + * @param Parser The class used for argument string parsing and arguments into a `T` + * @param Result The required result type of the main method. + * If this type is Any or Unit, any type will be accepted. + */ + trait Command[Parser[_], Result]: + + /** The getter for the `idx`th argument of type `T` + * + * @param idx The index of the argument + * @param defaultArgument Optional lambda to instantiate the default argument + */ + def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using Parser[T]): () => T + + /** The getter for a final varargs argument of type `T*` */ + def varargGetter[T](using Parser[T]): () => Seq[T] + + /** Run `program` if all arguments are valid if all arguments are valid + * + * @param program A function containing the call to the main method and instantiation of its arguments + */ + def run(program: () => Result): Unit + end Command + + /** Information about the main method + * + * @param name The name of the main method + * @param documentation The documentation of the main method without the `@param` documentation (see ParameterInfo.documentaion) + * @param parameters Information about the parameters of the main method + */ + final class CommandInfo( + val name: String, + val documentation: String, + val parameters: Seq[ParameterInfo], + ) + + /** Information about a parameter of a main method + * + * @param name The name of the parameter + * @param typeName The name of the parameter's type + * @param hasDefault If the parameter has a default argument + * @param isVarargs If the parameter is a varargs parameter (can only be true for the last parameter) + * @param documentation The documentation of the parameter (from `@param` documentation in the main method) + * @param annotations The annotations of the parameter that extend `ParameterAnnotation` + */ + final class ParameterInfo ( + val name: String, + val typeName: String, + val hasDefault: Boolean, + val isVarargs: Boolean, + val documentation: String, + val annotations: Seq[ParameterAnnotation], + ) + + /** Marker trait for annotations that will be included in the ParameterInfo annotations. */ + trait ParameterAnnotation extends StaticAnnotation + +end MainAnnotation diff --git a/project/MiMaFilters.scala b/project/MiMaFilters.scala index 2c4fd4992432..8bd16f134f57 100644 --- a/project/MiMaFilters.scala +++ b/project/MiMaFilters.scala @@ -3,13 +3,21 @@ import com.typesafe.tools.mima.core._ object MiMaFilters { val Library: Seq[ProblemFilter] = Seq( - - // Those are OK because user code is not allowed to inherit from Quotes: + // Experimental APIs that can be added in 3.2.0 or later + ProblemFilters.exclude[DirectMissingMethodProblem]("scala.runtime.Tuples.append"), ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#SymbolMethods.asQuotes"), ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#ClassDefModule.apply"), ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#SymbolModule.newClass"), ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#SymbolMethods.typeRef"), ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#SymbolMethods.termRef"), ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#TypeTreeModule.ref"), + + // Experimental `MainAnnotation` APIs. Can be added in 3.3.0 or later. + ProblemFilters.exclude[MissingClassProblem]("scala.annotation.MainAnnotation"), + ProblemFilters.exclude[MissingClassProblem]("scala.annotation.MainAnnotation$"), + ProblemFilters.exclude[MissingClassProblem]("scala.annotation.MainAnnotation$Command"), + ProblemFilters.exclude[MissingClassProblem]("scala.annotation.MainAnnotation$CommandInfo"), + ProblemFilters.exclude[MissingClassProblem]("scala.annotation.MainAnnotation$ParameterInfo"), + ProblemFilters.exclude[MissingClassProblem]("scala.annotation.MainAnnotation$ParameterAnnotation"), ) } diff --git a/project/resources/referenceReplacements/sidebar.yml b/project/resources/referenceReplacements/sidebar.yml index a8453449e73e..680b44d353d4 100644 --- a/project/resources/referenceReplacements/sidebar.yml +++ b/project/resources/referenceReplacements/sidebar.yml @@ -127,6 +127,7 @@ subsection: - page: reference/experimental/named-typeargs-spec.md - page: reference/experimental/numeric-literals.md - page: reference/experimental/explicit-nulls.md + - page: reference/experimental/main-annotation.md - page: reference/experimental/cc.md - page: reference/syntax.md - title: Language Versions diff --git a/tests/neg/main-annotation-mainannotation.scala b/tests/neg/main-annotation-mainannotation.scala new file mode 100644 index 000000000000..21e37d1779af --- /dev/null +++ b/tests/neg/main-annotation-mainannotation.scala @@ -0,0 +1,3 @@ +import scala.annotation.MainAnnotation + +@MainAnnotation def f(i: Int, n: Int) = () // error diff --git a/tests/run/main-annotation-example.check b/tests/run/main-annotation-example.check new file mode 100644 index 000000000000..97fcf11da08b --- /dev/null +++ b/tests/run/main-annotation-example.check @@ -0,0 +1,3 @@ +executing program +result: 28 +executed program diff --git a/tests/run/main-annotation-example.scala b/tests/run/main-annotation-example.scala new file mode 100644 index 000000000000..91036df44f57 --- /dev/null +++ b/tests/run/main-annotation-example.scala @@ -0,0 +1,59 @@ +import scala.annotation.* +import collection.mutable +import scala.util.CommandLineParser.FromString + +/** Sum all the numbers + * + * @param first Fist number to sum + * @param rest The rest of the numbers to sum + */ +@myMain def sum(first: Int, rest: Int*): Int = first + rest.sum + + +object Test: + def callMain(args: Array[String]): Unit = + val clazz = Class.forName("sum") + val method = clazz.getMethod("main", classOf[Array[String]]) + method.invoke(null, args) + + def main(args: Array[String]): Unit = + callMain(Array("23", "2", "3")) +end Test + +@experimental +class myMain extends MainAnnotation: + import MainAnnotation.{ Command, CommandInfo, ParameterInfo } + + /** A new command with arguments from `args` */ + def command(info: CommandInfo, args: Array[String]): Command[FromString, Int] = + if args.contains("--help") then + println(info.documentation) + System.exit(0) + assert(info.parameters.forall(!_.hasDefault), "Default arguments are not supported") + val (plainArgs, varargs) = + if info.parameters.last.isVarargs then + val numPlainArgs = info.parameters.length - 1 + assert(numPlainArgs <= args.length, "Not enough arguments") + (args.take(numPlainArgs), args.drop(numPlainArgs)) + else + assert(info.parameters.length <= args.length, "Not enough arguments") + assert(info.parameters.length >= args.length, "Too many arguments") + (args, Array.empty[String]) + new MyCommand(plainArgs, varargs) + + @experimental + class MyCommand(plainArgs: Seq[String], varargs: Seq[String]) extends Command[FromString, Int]: + + def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using parser: FromString[T]): () => T = + () => parser.fromString(plainArgs(idx)) + + def varargGetter[T](using parser: FromString[T]): () => Seq[T] = + () => varargs.map(arg => parser.fromString(arg)) + + def run(program: () => Int): Unit = + println("executing program") + val result = program() + println("result: " + result) + println("executed program") + end MyCommand +end myMain diff --git a/tests/run/main-annotation-homemade-annot-1.check b/tests/run/main-annotation-homemade-annot-1.check new file mode 100644 index 000000000000..4b7ff457bb11 --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-1.check @@ -0,0 +1,4 @@ +42 +42 +1 +2 diff --git a/tests/run/main-annotation-homemade-annot-1.scala b/tests/run/main-annotation-homemade-annot-1.scala new file mode 100644 index 000000000000..fabbc6348221 --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-1.scala @@ -0,0 +1,45 @@ +import scala.concurrent._ +import scala.annotation.* +import scala.collection.mutable +import ExecutionContext.Implicits.global +import duration._ +import util.CommandLineParser.FromString + +@mainAwait def get(wait: Int): Future[Int] = Future{ + Thread.sleep(1000 * wait) + 42 +} + +@mainAwait def getMany(wait: Int*): Future[Int] = Future{ + Thread.sleep(1000 * wait.sum) + wait.length +} + +object Test: + def callMain(cls: String, args: Array[String]): Unit = + val clazz = Class.forName(cls) + val method = clazz.getMethod("main", classOf[Array[String]]) + method.invoke(null, args) + + def main(args: Array[String]): Unit = + println(Await.result(get(1), Duration(2, SECONDS))) + callMain("get", Array("1")) + callMain("getMany", Array("1")) + callMain("getMany", Array("0", "1")) +end Test + +@experimental +class mainAwait(timeout: Int = 2) extends MainAnnotation: + import MainAnnotation.* + + // This is a toy example, it only works with positional args + def command(info: CommandInfo, args: Array[String]): Command[FromString, Future[Any]] = + new Command[FromString, Future[Any]]: + override def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = + () => p.fromString(args(idx)) + + override def varargGetter[T](using p: FromString[T]): () => Seq[T] = + () => for i <- ((info.parameters.length-1) until args.length) yield p.fromString(args(i)) + + override def run(f: () => Future[Any]): Unit = println(Await.result(f(), Duration(timeout, SECONDS))) +end mainAwait diff --git a/tests/run/main-annotation-homemade-annot-2.check b/tests/run/main-annotation-homemade-annot-2.check new file mode 100644 index 000000000000..f57ec79b8dbd --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-2.check @@ -0,0 +1,11 @@ +I was run! +A +I was run! +A +I was run! +A +Here are some colors: +Purple smart, Blue fast, White fashion, Yellow quiet, Orange honest, Pink loud +This will be printed, but nothing more. +This will be printed, but nothing more. +This will be printed, but nothing more. diff --git a/tests/run/main-annotation-homemade-annot-2.scala b/tests/run/main-annotation-homemade-annot-2.scala new file mode 100644 index 000000000000..e2eecfbd6fcc --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-2.scala @@ -0,0 +1,50 @@ +import scala.collection.mutable +import scala.annotation.* +import util.CommandLineParser.FromString + +@myMain()("A") +def foo1(): Unit = println("I was run!") + +@myMain(0)("This should not be printed") +def foo2() = throw new Exception("This should not be run") + +@myMain(1)("Purple smart", "Blue fast", "White fashion", "Yellow quiet", "Orange honest", "Pink loud") +def foo3() = println("Here are some colors:") + +@myMain()() +def foo4() = println("This will be printed, but nothing more.") + +object Test: + val allClazzes: Seq[Class[?]] = + LazyList.from(1).map(i => scala.util.Try(Class.forName("foo" + i.toString))).takeWhile(_.isSuccess).map(_.get) + + def callMains(): Unit = + for (clazz <- allClazzes) + val method = clazz.getMethod("main", classOf[Array[String]]) + method.invoke(null, Array[String]()) + + def main(args: Array[String]) = + callMains() +end Test + +// This is a toy example, it only works with positional args +@experimental +class myMain(runs: Int = 3)(after: String*) extends MainAnnotation: + import MainAnnotation.* + + def command(info: CommandInfo, args: Array[String]): Command[FromString, Any] = + new Command[FromString, Any]: + + override def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = + () => p.fromString(args(idx)) + + override def varargGetter[T](using p: FromString[T]): () => Seq[T] = + () => for i <- (info.parameters.length until args.length) yield p.fromString(args(i)) + + override def run(f: () => Any): Unit = + for (_ <- 1 to runs) + f() + if after.length > 0 then println(after.mkString(", ")) + end run + end command +end myMain diff --git a/tests/run/main-annotation-homemade-annot-3.check b/tests/run/main-annotation-homemade-annot-3.check new file mode 100644 index 000000000000..cd0875583aab --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-3.check @@ -0,0 +1 @@ +Hello world! diff --git a/tests/run/main-annotation-homemade-annot-3.scala b/tests/run/main-annotation-homemade-annot-3.scala new file mode 100644 index 000000000000..640f6a934004 --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-3.scala @@ -0,0 +1,24 @@ +import scala.annotation.* +import scala.util.CommandLineParser.FromString + +@mainNoArgs def foo() = println("Hello world!") + +object Test: + def main(args: Array[String]) = + val clazz = Class.forName("foo") + val method = clazz.getMethod("main", classOf[Array[String]]) + method.invoke(null, Array[String]()) +end Test + +@experimental +class mainNoArgs extends MainAnnotation: + import MainAnnotation.* + + def command(info: CommandInfo, args: Array[String]): Command[FromString, Any] = + new Command[FromString, Any]: + override def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = ??? + + override def varargGetter[T](using p: FromString[T]): () => Seq[T] = ??? + + override def run(program: () => Any): Unit = program() + end command diff --git a/tests/run/main-annotation-homemade-annot-4.check b/tests/run/main-annotation-homemade-annot-4.check new file mode 100644 index 000000000000..cd0875583aab --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-4.check @@ -0,0 +1 @@ +Hello world! diff --git a/tests/run/main-annotation-homemade-annot-4.scala b/tests/run/main-annotation-homemade-annot-4.scala new file mode 100644 index 000000000000..602744398e74 --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-4.scala @@ -0,0 +1,24 @@ +import scala.annotation.* +import scala.util.CommandLineParser.FromString + +@mainManyArgs(1, "B", 3) def foo() = println("Hello world!") + +object Test: + def main(args: Array[String]) = + val clazz = Class.forName("foo") + val method = clazz.getMethod("main", classOf[Array[String]]) + method.invoke(null, Array[String]()) +end Test + +@experimental +class mainManyArgs(i1: Int, s2: String, i3: Int) extends MainAnnotation: + import MainAnnotation.* + + def command(info: CommandInfo, args: Array[String]): Command[FromString, Any] = + new Command[FromString, Any]: + override def argGetter[T](idx: Int, optDefaultGetter: Option[() => T])(using p: FromString[T]): () => T = ??? + + override def varargGetter[T](using p: FromString[T]): () => Seq[T] = ??? + + override def run(program: () => Any): Unit = program() + end command diff --git a/tests/run/main-annotation-homemade-annot-5.check b/tests/run/main-annotation-homemade-annot-5.check new file mode 100644 index 000000000000..7d60d6656c81 --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-5.check @@ -0,0 +1,2 @@ +Hello world! +Hello world! diff --git a/tests/run/main-annotation-homemade-annot-5.scala b/tests/run/main-annotation-homemade-annot-5.scala new file mode 100644 index 000000000000..e529ac304efe --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-5.scala @@ -0,0 +1,26 @@ +import scala.annotation.* +import scala.util.CommandLineParser.FromString + +@mainManyArgs(Some(1)) def foo() = println("Hello world!") +@mainManyArgs(None) def bar() = println("Hello world!") + +object Test: + def main(args: Array[String]) = + for (methodName <- List("foo", "bar")) + val clazz = Class.forName(methodName) + val method = clazz.getMethod("main", classOf[Array[String]]) + method.invoke(null, Array[String]()) +end Test + +@experimental +class mainManyArgs(o: Option[Int]) extends MainAnnotation: + import MainAnnotation.* + + def command(info: CommandInfo, args: Array[String]): Command[FromString, Any] = + new Command[FromString, Any]: + override def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = ??? + + override def varargGetter[T](using p: FromString[T]): () => Seq[T] = ??? + + override def run(program: () => Any): Unit = program() + end command diff --git a/tests/run/main-annotation-homemade-annot-6.check b/tests/run/main-annotation-homemade-annot-6.check new file mode 100644 index 000000000000..b9e33bf3e406 --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-6.check @@ -0,0 +1,28 @@ +command( + Array(), + foo, + "Foo docs", + Seq( + ParameterInfo(name="i", typeName="scala.Int", hasDefault=false, isVarargs=false, documentation="", annotations=List()), + ParameterInfo(name="j", typeName="java.lang.String", hasDefault=true, isVarargs=false, documentation="", annotations=List()) + )* +) +argGetter(0, None) +argGetter(1, Some(2)) +run() +foo(42, abc) + +command( + Array(), + bar, + "Bar docs", + Seq( + ParameterInfo(name="i", typeName="scala.collection.immutable.List[Int]", hasDefault=false, isVarargs=false, documentation="the first parameter", annotations=List(MyParamAnnot(3))), + ParameterInfo(name="rest", typeName="scala.Int", hasDefault=false, isVarargs=true, documentation="", annotations=List()) + )* +) +argGetter(0, None) +varargGetter() +run() +bar(List(42), 42, 42) + diff --git a/tests/run/main-annotation-homemade-annot-6.scala b/tests/run/main-annotation-homemade-annot-6.scala new file mode 100644 index 000000000000..5d1c227d0c72 --- /dev/null +++ b/tests/run/main-annotation-homemade-annot-6.scala @@ -0,0 +1,63 @@ +import scala.annotation.* + +/** Foo docs */ +@myMain def foo(i: Int, j: String = "2") = println(s"foo($i, $j)") +/** Bar docs + * + * @param i the first parameter + */ +@myMain def bar(@MyParamAnnot(3) i: List[Int], rest: Int*) = println(s"bar($i, ${rest.mkString(", ")})") + +object Test: + def main(args: Array[String]) = + for (methodName <- List("foo", "bar")) + val clazz = Class.forName(methodName) + val method = clazz.getMethod("main", classOf[Array[String]]) + method.invoke(null, Array[String]()) +end Test + +@experimental +class myMain extends MainAnnotation: + import MainAnnotation.* + + def command(info: CommandInfo, args: Array[String]): Command[Make, Any] = + def paramInfoString(paramInfo: ParameterInfo) = + import paramInfo.* + s" ParameterInfo(name=\"$name\", typeName=\"$typeName\", hasDefault=$hasDefault, isVarargs=$isVarargs, documentation=\"$documentation\", annotations=$annotations)" + println( + s"""command( + | ${args.mkString("Array(", ", ", ")")}, + | ${info.name}, + | "${info.documentation}", + | ${info.parameters.map(paramInfoString).mkString("Seq(\n", ",\n", "\n )*")} + |)""".stripMargin) + new Command[Make, Any]: + override def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using p: Make[T]): () => T = + println(s"argGetter($idx, ${defaultArgument.map(_())})") + () => p.make + + override def varargGetter[T](using p: Make[T]): () => Seq[T] = + println("varargGetter()") + () => Seq(p.make, p.make) + + override def run(f: () => Any): Unit = + println("run()") + f() + println() + end command + +@experimental +case class MyParamAnnot(n: Int) extends MainAnnotation.ParameterAnnotation + +trait Make[T]: + def make: T + +given Make[Int] with + def make: Int = 42 + + +given Make[String] with + def make: String = "abc" + +given [T: Make]: Make[List[T]] with + def make: List[T] = List(summon[Make[T]].make) diff --git a/tests/run/main-annotation-newMain.scala b/tests/run/main-annotation-newMain.scala new file mode 100644 index 000000000000..c2a538443b24 --- /dev/null +++ b/tests/run/main-annotation-newMain.scala @@ -0,0 +1,307 @@ +import scala.annotation.* +import collection.mutable +import scala.util.CommandLineParser.FromString + +@newMain def happyBirthday(age: Int, name: String, others: String*) = + val suffix = + age % 100 match + case 11 | 12 | 13 => "th" + case _ => + age % 10 match + case 1 => "st" + case 2 => "nd" + case 3 => "rd" + case _ => "th" + val bldr = new StringBuilder(s"Happy $age$suffix birthday, $name") + for other <- others do bldr.append(" and ").append(other) + println(bldr) + + +object Test: + def callMain(args: Array[String]): Unit = + val clazz = Class.forName("happyBirthday") + val method = clazz.getMethod("main", classOf[Array[String]]) + method.invoke(null, args) + + def main(args: Array[String]): Unit = + callMain(Array("23", "Lisa", "Peter")) +end Test + + + +@experimental +final class newMain extends MainAnnotation: + import newMain._ + import MainAnnotation._ + + def command(info: CommandInfo, args: Array[String]): Command[FromString, Any] = + new Command[FromString, Any]: + + private inline val argMarker = "--" + private inline val shortArgMarker = "-" + + /** + * The name of the special argument to display the method's help. + * If one of the method's parameters is called the same, will be ignored. + */ + private inline val helpArg = "help" + + /** + * The short name of the special argument to display the method's help. + * If one of the method's parameters uses the same short name, will be ignored. + */ + private inline val shortHelpArg = 'h' + private var shortHelpIsOverridden = false + + private inline val maxUsageLineLength = 120 + + /** A map from argument canonical name (the name of the parameter in the method definition) to parameter informations */ + private val nameToParameterInfo: Map[String, ParameterInfo] = info.parameters.map(infos => infos.name -> infos).toMap + + private val (positionalArgs, byNameArgs, invalidByNameArgs, helpIsOverridden) = { + val namesToCanonicalName: Map[String, String] = info.parameters.flatMap( + infos => + var names = getAlternativeNames(infos) + val canonicalName = infos.name + if nameIsValid(canonicalName) then names = canonicalName +: names + names.map(_ -> canonicalName) + ).toMap + val shortNamesToCanonicalName: Map[Char, String] = info.parameters.flatMap( + infos => + var names = getShortNames(infos) + val canonicalName = infos.name + if shortNameIsValid(canonicalName) then names = canonicalName(0) +: names + names.map(_ -> canonicalName) + ).toMap + + val helpIsOverridden = namesToCanonicalName.exists((name, _) => name == helpArg) + shortHelpIsOverridden = shortNamesToCanonicalName.exists((name, _) => name == shortHelpArg) + + def getCanonicalArgName(arg: String): Option[String] = + if arg.startsWith(argMarker) && arg.length > argMarker.length then + namesToCanonicalName.get(arg.drop(argMarker.length)) + else if arg.startsWith(shortArgMarker) && arg.length == shortArgMarker.length + 1 then + shortNamesToCanonicalName.get(arg(shortArgMarker.length)) + else + None + + def isArgName(arg: String): Boolean = + val isFullName = arg.startsWith(argMarker) + val isShortName = arg.startsWith(shortArgMarker) && arg.length == shortArgMarker.length + 1 && shortNameIsValid(arg(shortArgMarker.length)) + isFullName || isShortName + + def recurse(remainingArgs: Seq[String], pa: mutable.Queue[String], bna: Seq[(String, String)], ia: Seq[String]): (mutable.Queue[String], Seq[(String, String)], Seq[String]) = + remainingArgs match { + case Seq() => + (pa, bna, ia) + case argName +: argValue +: rest if isArgName(argName) => + getCanonicalArgName(argName) match { + case Some(canonicalName) => recurse(rest, pa, bna :+ (canonicalName -> argValue), ia) + case None => recurse(rest, pa, bna, ia :+ argName) + } + case arg +: rest => + recurse(rest, pa :+ arg, bna, ia) + } + + val (pa, bna, ia) = recurse(args.toSeq, mutable.Queue.empty, Vector(), Vector()) + val nameToArgValues: Map[String, Seq[String]] = if bna.isEmpty then Map.empty else bna.groupMapReduce(_._1)(p => List(p._2))(_ ++ _) + (pa, nameToArgValues, ia, helpIsOverridden) + } + + /** A buffer for all errors */ + private val errors = new mutable.ArrayBuffer[String] + + /** Issue an error, and return an uncallable getter */ + private def error(msg: String): () => Nothing = + errors += msg + () => throw new AssertionError("trying to get invalid argument") + + private inline def nameIsValid(name: String): Boolean = + name.length > 1 // TODO add more checks for illegal characters + + private inline def shortNameIsValid(name: String): Boolean = + name.length == 1 && shortNameIsValid(name(0)) + + private inline def shortNameIsValid(shortName: Char): Boolean = + ('A' <= shortName && shortName <= 'Z') || ('a' <= shortName && shortName <= 'z') + + private def getNameWithMarker(name: String | Char): String = name match { + case c: Char => shortArgMarker + c + case s: String if shortNameIsValid(s) => shortArgMarker + s + case s => argMarker + s + } + + private def convert[T](argName: String, arg: String)(using p: FromString[T]): () => T = + p.fromStringOption(arg) match + case Some(t) => () => t + case None => error(s"invalid argument for $argName: $arg") + + private def usage(): Unit = + def argsUsage: Seq[String] = + for info <- info.parameters yield + val canonicalName = getNameWithMarker(info.name) + val shortNames = getShortNames(info).map(getNameWithMarker) + val alternativeNames = getAlternativeNames(info).map(getNameWithMarker) + val namesPrint = (canonicalName +: alternativeNames ++: shortNames).mkString("[", " | ", "]") + if info.isVarargs then s"[<${info.typeName}> [<${info.typeName}> [...]]]" + else if info.hasDefault then s"[$namesPrint <${info.typeName}>]" + else s"$namesPrint <${info.typeName}>" + end for + + def wrapArgumentUsages(argsUsage: Seq[String], maxLength: Int): Seq[String] = { + def recurse(args: Seq[String], currentLine: String, acc: Vector[String]): Seq[String] = + (args, currentLine) match { + case (Nil, "") => acc + case (Nil, l) => (acc :+ l) + case (arg +: t, "") => recurse(t, arg, acc) + case (arg +: t, l) if l.length + 1 + arg.length <= maxLength => recurse(t, s"$l $arg", acc) + case (arg +: t, l) => recurse(t, arg, acc :+ l) + } + + recurse(argsUsage, "", Vector()).toList + } + + val usageBeginning = s"Usage: ${info.name} " + val argsOffset = usageBeginning.length + val usages = wrapArgumentUsages(argsUsage, maxUsageLineLength - argsOffset) + + println(usageBeginning + usages.mkString("\n" + " " * argsOffset)) + end usage + + private def explain(): Unit = + inline def shiftLines(s: Seq[String], shift: Int): String = s.map(" " * shift + _).mkString("\n") + + def wrapLongLine(line: String, maxLength: Int): List[String] = { + def recurse(s: String, acc: Vector[String]): Seq[String] = + val lastSpace = s.trim.nn.lastIndexOf(' ', maxLength) + if ((s.length <= maxLength) || (lastSpace < 0)) + acc :+ s + else { + val (shortLine, rest) = s.splitAt(lastSpace) + recurse(rest.trim.nn, acc :+ shortLine) + } + + recurse(line, Vector()).toList + } + + if (info.documentation.nonEmpty) + println(wrapLongLine(info.documentation, maxUsageLineLength).mkString("\n")) + if (nameToParameterInfo.nonEmpty) { + val argNameShift = 2 + val argDocShift = argNameShift + 2 + + println("Arguments:") + for info <- info.parameters do + val canonicalName = getNameWithMarker(info.name) + val shortNames = getShortNames(info).map(getNameWithMarker) + val alternativeNames = getAlternativeNames(info).map(getNameWithMarker) + val otherNames = (alternativeNames ++: shortNames) match { + case Seq() => "" + case names => names.mkString("(", ", ", ") ") + } + val argDoc = StringBuilder(" " * argNameShift) + argDoc.append(s"$canonicalName $otherNames- ${info.typeName}") + + if info.isVarargs then argDoc.append(" (vararg)") + else if info.hasDefault then argDoc.append(" (optional)") + + val doc = info.documentation + if (doc.nonEmpty) { + val shiftedDoc = + doc.split("\n").nn + .map(line => shiftLines(wrapLongLine(line.nn, maxUsageLineLength - argDocShift), argDocShift)) + .mkString("\n") + argDoc.append("\n").append(shiftedDoc) + } + + println(argDoc) + end for + } + end explain + + private def getAliases(paramInfos: ParameterInfo): Seq[String] = + paramInfos.annotations.collect{ case a: Alias => a }.flatMap(_.aliases) + + private def getAlternativeNames(paramInfos: ParameterInfo): Seq[String] = + getAliases(paramInfos).filter(nameIsValid(_)) + + private def getShortNames(paramInfos: ParameterInfo): Seq[Char] = + getAliases(paramInfos).filter(shortNameIsValid(_)).map(_(0)) + + private def getInvalidNames(paramInfos: ParameterInfo): Seq[String | Char] = + getAliases(paramInfos).filter(name => !nameIsValid(name) && !shortNameIsValid(name)) + + override def argGetter[T](idx: Int, optDefaultGetter: Option[() => T])(using p: FromString[T]): () => T = + val name = info.parameters(idx).name + val parameterInfo = nameToParameterInfo(name) + // TODO: Decide which string is associated with this arg when constructing the command. + // Here we should only get the string for this argument, apply it to the parser and handle parsing errors. + // Should be able to get the argument from its index. + byNameArgs.get(name) match { + case Some(Nil) => + throw AssertionError(s"$name present in byNameArgs, but it has no argument value") + case Some(argValues) => + if argValues.length > 1 then + // Do not accept multiple values + // Remove this test to take last given argument + error(s"more than one value for $name: ${argValues.mkString(", ")}") + else + convert(name, argValues.last) + case None => + if positionalArgs.length > 0 then + convert(name, positionalArgs.dequeue) + else if optDefaultGetter.nonEmpty then + optDefaultGetter.get + else + error(s"missing argument for $name") + } + end argGetter + + override def varargGetter[T](using p: FromString[T]): () => Seq[T] = + val name = info.parameters.last.name + // TODO: Decide which strings are associated with the varargs when constructing the command. + // Here we should only get the strings for this argument, apply them to the parser and handle parsing errors. + // Should be able to get the argument from its index (last). + val byNameGetters = byNameArgs.getOrElse(name, Seq()).map(arg => convert(name, arg)) + val positionalGetters = positionalArgs.removeAll.map(arg => convert(name, arg)) + // First take arguments passed by name, then those passed by position + () => (byNameGetters ++ positionalGetters).map(_()) + + override def run(f: () => Any): Unit = + // Check aliases unicity + val nameAndCanonicalName = nameToParameterInfo.toList.flatMap { + case (canonicalName, infos) => (canonicalName +: getAlternativeNames(infos) ++: getShortNames(infos)).map(_ -> canonicalName) + } + val nameToCanonicalNames = nameAndCanonicalName.groupMap(_._1)(_._2) + + for (name, canonicalNames) <- nameToCanonicalNames if canonicalNames.length > 1 + do throw IllegalArgumentException(s"$name is used for multiple parameters: ${canonicalNames.mkString(", ")}") + + // Check aliases validity + val problematicNames = nameToParameterInfo.toList.flatMap((_, infos) => getInvalidNames(infos)) + if problematicNames.length > 0 then throw IllegalArgumentException(s"The following aliases are invalid: ${problematicNames.mkString(", ")}") + + // Handle unused and invalid args + for (remainingArg <- positionalArgs) error(s"unused argument: $remainingArg") + for (invalidArg <- invalidByNameArgs) error(s"unknown argument name: $invalidArg") + + val displayHelp = + (!helpIsOverridden && args.contains(getNameWithMarker(helpArg))) || (!shortHelpIsOverridden && args.contains(getNameWithMarker(shortHelpArg))) + + if displayHelp then + usage() + println() + explain() + else if errors.nonEmpty then + for msg <- errors do println(s"Error: $msg") + usage() + else + f() + end run + end command +end newMain + +object newMain: + @experimental + final class Alias(val aliases: String*) extends MainAnnotation.ParameterAnnotation +end newMain From 813e059afe04428089a5b6e79b7ec956c0478741 Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Tue, 5 Apr 2022 15:48:15 +0200 Subject: [PATCH 2/2] Simplify `MainAnnotation` interface Remove the `Command` class and place the `argGetter`, `varargsGetter` and `run` methods directly in the `MainAnnotation` interface. Now `command` pre-processes the arguments which clearly states which strings will be used for each argument. This simplifies the implementation of the `MainAnnotation` methods. --- .../dotty/tools/dotc/ast/MainProxies.scala | 104 ++-- .../dotty/tools/dotc/core/Definitions.scala | 4 +- .../src/dotty/tools/dotc/core/StdNames.scala | 1 + .../reference/experimental/main-annotation.md | 92 ++-- .../src/scala/annotation/MainAnnotation.scala | 99 ++-- .../reference-expected-links.txt | 1 + tests/run/main-annotation-example.scala | 65 +-- .../main-annotation-homemade-annot-1.scala | 17 +- .../main-annotation-homemade-annot-2.scala | 25 +- .../main-annotation-homemade-annot-3.scala | 13 +- .../main-annotation-homemade-annot-4.scala | 14 +- .../main-annotation-homemade-annot-5.scala | 13 +- .../main-annotation-homemade-annot-6.check | 15 +- .../main-annotation-homemade-annot-6.scala | 37 +- tests/run/main-annotation-newMain.scala | 473 +++++++++--------- 15 files changed, 511 insertions(+), 462 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/MainProxies.scala b/compiler/src/dotty/tools/dotc/ast/MainProxies.scala index 01ae50850b57..5e969c0c38c9 100644 --- a/compiler/src/dotty/tools/dotc/ast/MainProxies.scala +++ b/compiler/src/dotty/tools/dotc/ast/MainProxies.scala @@ -143,6 +143,7 @@ object MainProxies { * */ * @myMain(80) def f( * @myMain.Alias("myX") x: S, + * y: S, * ys: T* * ) = ... * @@ -150,22 +151,23 @@ object MainProxies { * * final class f { * static def main(args: Array[String]): Unit = { - * val cmd = new myMain(80).command( - * info = new CommandInfo( - * name = "f", - * documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.", - * parameters = Seq( - * new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX"))) - * new scala.annotation.MainAnnotation.ParameterInfo("ys", "T", false, false, "all my params y", Seq()) - * ) + * val annotation = new myMain(80) + * val info = new Info( + * name = "f", + * documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.", + * parameters = Seq( + * new scala.annotation.MainAnnotation.Parameter("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX"))), + * new scala.annotation.MainAnnotation.Parameter("y", "S", true, false, "", Seq()), + * new scala.annotation.MainAnnotation.Parameter("ys", "T", false, true, "all my params y", Seq()) * ) - * args = args - * ) - * - * val args0: () => S = cmd.argGetter[S](0, None) - * val args1: () => Seq[T] = cmd.varargGetter[T] - * - * cmd.run(() => f(args0(), args1()*)) + * ), + * val command = annotation.command(info, args) + * if command.isDefined then + * val cmd = command.get + * val args0: () => S = annotation.argGetter[S](info.parameters(0), cmd(0), None) + * val args1: () => S = annotation.argGetter[S](info.parameters(1), mainArgs(1), Some(() => sum$default$1())) + * val args2: () => Seq[T] = annotation.varargGetter[T](info.parameters(2), cmd.drop(2)) + * annotation.run(() => f(args0(), args1(), args2()*)) * } * } */ @@ -229,7 +231,7 @@ object MainProxies { * * A ParamInfo has the following shape * ``` - * new scala.annotation.MainAnnotation.ParameterInfo("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX"))) + * new scala.annotation.MainAnnotation.Parameter("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX"))) * ``` */ def parameterInfos(mt: MethodType): List[Tree] = @@ -252,33 +254,34 @@ object MainProxies { val constructorArgs = List(param, paramTypeStr, hasDefault, isRepeated, paramDoc) .map(value => Literal(Constant(value))) - New(TypeTree(defn.MainAnnotationParameterInfo.typeRef), List(constructorArgs :+ paramAnnots)) + New(TypeTree(defn.MainAnnotationParameter.typeRef), List(constructorArgs :+ paramAnnots)) end parameterInfos /** * Creates a list of references and definitions of arguments. * The goal is to create the - * `val args0: () => S = cmd.argGetter[S](0, None)` + * `val args0: () => S = annotation.argGetter[S](0, cmd(0), None)` * part of the code. */ def argValDefs(mt: MethodType): List[ValDef] = for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield - val argName = nme.args ++ idx.toString - val isRepeated = formal.isRepeatedParam - val formalType = if isRepeated then formal.argTypes.head else formal - val getterName = if isRepeated then nme.varargGetter else nme.argGetter - val defaultValueGetterOpt = defaultValueSymbols.get(idx) match - case None => ref(defn.NoneModule.termRef) - case Some(dvSym) => - val value = unitToValue(ref(dvSym.termRef)) - Apply(ref(defn.SomeClass.companionModule.termRef), value) - val argGetter0 = TypeApply(Select(Ident(nme.cmd), getterName), TypeTree(formalType) :: Nil) - val argGetter = - if isRepeated then argGetter0 - else Apply(argGetter0, List(Literal(Constant(idx)), defaultValueGetterOpt)) - - ValDef(argName, TypeTree(), argGetter) + val argName = nme.args ++ idx.toString + val isRepeated = formal.isRepeatedParam + val formalType = if isRepeated then formal.argTypes.head else formal + val getterName = if isRepeated then nme.varargGetter else nme.argGetter + val defaultValueGetterOpt = defaultValueSymbols.get(idx) match + case None => ref(defn.NoneModule.termRef) + case Some(dvSym) => + val value = unitToValue(ref(dvSym.termRef)) + Apply(ref(defn.SomeClass.companionModule.termRef), value) + val argGetter0 = TypeApply(Select(Ident(nme.annotation), getterName), TypeTree(formalType) :: Nil) + val index = Literal(Constant(idx)) + val paramInfo = Apply(Select(Ident(nme.info), nme.parameters), index) + val argGetter = + if isRepeated then Apply(argGetter0, List(paramInfo, Apply(Select(Ident(nme.cmd), nme.drop), List(index)))) + else Apply(argGetter0, List(paramInfo, Apply(Ident(nme.cmd), List(index)), defaultValueGetterOpt)) + ValDef(argName, TypeTree(), argGetter) end argValDefs @@ -318,18 +321,39 @@ object MainProxies { val nameTree = Literal(Constant(mainFun.showName)) val docTree = Literal(Constant(documentation.mainDoc)) val paramInfos = Apply(ref(defn.SeqModule.termRef), parameterInfos) - New(TypeTree(defn.MainAnnotationCommandInfo.typeRef), List(List(nameTree, docTree, paramInfos))) + New(TypeTree(defn.MainAnnotationInfo.typeRef), List(List(nameTree, docTree, paramInfos))) - val cmd = ValDef( - nme.cmd, + val annotVal = ValDef( + nme.annotation, + TypeTree(), + instantiateAnnotation(mainAnnot) + ) + val infoVal = ValDef( + nme.info, + TypeTree(), + cmdInfo + ) + val command = ValDef( + nme.command, TypeTree(), Apply( - Select(instantiateAnnotation(mainAnnot), nme.command), - List(cmdInfo, Ident(nme.args)) + Select(Ident(nme.annotation), nme.command), + List(Ident(nme.info), Ident(nme.args)) ) ) - val run = Apply(Select(Ident(nme.cmd), nme.run), mainCall) - val body = Block(cmdInfo :: cmd :: args, run) + val argsVal = ValDef( + nme.cmd, + TypeTree(), + Select(Ident(nme.command), nme.get) + ) + val run = Apply(Select(Ident(nme.annotation), nme.run), mainCall) + val body0 = If( + Select(Ident(nme.command), nme.isDefined), + Block(argsVal :: args, run), + EmptyTree + ) + val body = Block(List(annotVal, infoVal, command), body0) // TODO add `if (cmd.nonEmpty)` + val mainArg = ValDef(nme.args, TypeTree(defn.ArrayType.appliedTo(defn.StringType)), EmptyTree) .withFlags(Param) /** Replace typed `Ident`s that have been typed with a TypeSplice with the reference to the symbol. diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 9e70102fa1dd..8aaaff52708d 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -856,8 +856,8 @@ class Definitions { @tu lazy val XMLTopScopeModule: Symbol = requiredModule("scala.xml.TopScope") @tu lazy val MainAnnotationClass: ClassSymbol = requiredClass("scala.annotation.MainAnnotation") - @tu lazy val MainAnnotationCommandInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.CommandInfo") - @tu lazy val MainAnnotationParameterInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterInfo") + @tu lazy val MainAnnotationInfo: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Info") + @tu lazy val MainAnnotationParameter: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Parameter") @tu lazy val MainAnnotationParameterAnnotation: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterAnnotation") @tu lazy val MainAnnotationCommand: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Command") diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index bb5efcc01a5e..dc9e48b65f47 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -543,6 +543,7 @@ object StdNames { val ordinalDollar: N = "$ordinal" val ordinalDollar_ : N = "_$ordinal" val origin: N = "origin" + val parameters: N = "parameters" val parts: N = "parts" val postfixOps: N = "postfixOps" val prefix : N = "prefix" diff --git a/docs/_docs/reference/experimental/main-annotation.md b/docs/_docs/reference/experimental/main-annotation.md index c87d143ed15f..d2172d97a284 100644 --- a/docs/_docs/reference/experimental/main-annotation.md +++ b/docs/_docs/reference/experimental/main-annotation.md @@ -13,33 +13,35 @@ When a users annotates a method with an annotation that extends `MainAnnotation` * @param first Fist number to sum * @param rest The rest of the numbers to sum */ -@myMain def sum(first: Int, rest: Int*): Int = first + rest.sum +@myMain def sum(first: Int, second: Int = 0, rest: Int*): Int = first + second + rest.sum ``` ```scala object foo { def main(args: Array[String]): Unit = { - - val cmd = new myMain().command( - info = new CommandInfo( - name = "sum", - documentation = "Sum all the numbers", - parameters = Seq( - new ParameterInfo("first", "scala.Int", hasDefault=false, isVarargs=false, "Fist number to sum", Seq()), - new ParameterInfo("rest", "scala.Int" , hasDefault=false, isVarargs=true, "The rest of the numbers to sum", Seq()) - ) - ), - args = args + val mainAnnot = new myMain() + val info = new Info( + name = "foo.main", + documentation = "Sum all the numbers", + parameters = Seq( + new Parameter("first", "scala.Int", hasDefault=false, isVarargs=false, "Fist number to sum", Seq()), + new Parameter("second", "scala.Int", hasDefault=true, isVarargs=false, "", Seq()), + new Parameter("rest", "scala.Int" , hasDefault=false, isVarargs=true, "The rest of the numbers to sum", Seq()) + ) ) - val args0 = cmd.argGetter[Int](0, None) // using a parser of Int - val args1 = cmd.varargGetter[Int] // using a parser of Int - cmd.run(() => sum(args0(), args1()*)) + val mainArgsOpt = mainAnnot.command(info, args) + if mainArgsOpt.isDefined then + val mainArgs = mainArgsOpt.get + val args0 = mainAnnot.argGetter[Int](info.parameters(0), mainArgs(0), None) // using a parser of Int + val args1 = mainAnnot.argGetter[Int](info.parameters(1), mainArgs(1), Some(() => sum$default$1())) // using a parser of Int + val args2 = mainAnnot.varargGetter[Int](info.parameters(2), mainArgs.drop(2)) // using a parser of Int + mainAnnot.run(() => sum(args0(), args1(), args2()*)) } } ``` -The implementation of the `main` method first instantiates the annotation and then creates a `Command`. -When creating the `Command`, the arguments can be checked and preprocessed. +The implementation of the `main` method first instantiates the annotation and then call `command`. +When calling the `command`, the arguments can be checked and preprocessed. Then it defines a series of argument getters calling `argGetter` for each parameter and `varargGetter` for the last one if it is a varargs. `argGetter` gets an optional lambda that computes the default argument. Finally, the `run` method is called to run the application. It receives a by-name argument that contains the call the annotated method with the instantiations arguments (using the lambdas from `argGetter`/`varargGetter`). @@ -50,42 +52,46 @@ Example of implementation of `myMain` that takes all arguments positionally. It // Parser used to parse command line arguments import scala.util.CommandLineParser.FromString[T] -// Result type of the annotated method is Int -class myMain extends MainAnnotation: - import MainAnnotation.{ ParameterInfo, Command } +// Result type of the annotated method is Int and arguments are parsed using FromString +@experimental class myMain extends MainAnnotation[FromString, Int]: + import MainAnnotation.{ Info, Parameter } - /** A new command with arguments from `args` */ - def command(info: CommandInfo, args: Array[String]): Command[FromString, Int] = + def command(info: Info, args: Seq[String]): Option[Seq[String]] = if args.contains("--help") then println(info.documentation) - // TODO: Print documentation of the parameters - System.exit(0) - assert(info.parameters.forall(!_.hasDefault), "Default arguments are not supported") - val (plainArgs, varargs) = - if info.parameters.last.isVarargs then - val numPlainArgs = info.parameters.length - 1 - assert(numPlainArgs <= args.length, "Not enough arguments") - (args.take(numPlainArgs), args.drop(numPlainArgs)) + None // do not parse or run the program + else if info.parameters.exists(_.hasDefault) then + println("Default arguments are not supported") + None + else if info.hasVarargs then + val numPlainArgs = info.parameters.length - 1 + if numPlainArgs <= args.length then + println("Not enough arguments") + None + else + Some(args) + else + if info.parameters.length <= args.length then + println("Not enough arguments") + None + else if info.parameters.length >= args.length then + println("Too many arguments") + None else - assert(info.parameters.length <= args.length, "Not enough arguments") - assert(info.parameters.length >= args.length, "Too many arguments") - (args, Array.empty[String]) - new MyCommand(plainArgs, varargs) + Some(args) - @experimental - class MyCommand(plainArgs: Seq[String], varargs: Seq[String]) extends Command[FromString, Int]: + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using parser: FromString[T]): () => T = + () => parser.fromString(arg) - def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using parser: FromString[T]): () => T = - () => parser.fromString(plainArgs(idx)) + def varargGetter[T](param: Parameter, args: Seq[String])(using parser: FromString[T]): () => Seq[T] = + () => args.map(arg => parser.fromString(arg)) - def varargGetter[T](using parser: FromString[T]): () => Seq[T] = - () => varargs.map(arg => parser.fromString(arg)) + def run(program: () => Int): Unit = + println("executing program") - def run(program: () => Int): Unit = - println("executing program") + try { val result = program() println("result: " + result) println("executed program") - end MyCommand end myMain ``` diff --git a/library/src/scala/annotation/MainAnnotation.scala b/library/src/scala/annotation/MainAnnotation.scala index 6e30ee6f69a3..9d2f5362ba15 100644 --- a/library/src/scala/annotation/MainAnnotation.scala +++ b/library/src/scala/annotation/MainAnnotation.scala @@ -17,83 +17,90 @@ package scala.annotation * * @param first Fist number to sum * * @param rest The rest of the numbers to sum * */ - * @myMain def sum(first: Int, rest: Int*): Int = first + rest.sum + * @myMain def sum(first: Int, second: Int = 0, rest: Int*): Int = first + second + rest.sum * ``` * generates * ```scala * object foo { * def main(args: Array[String]): Unit = { - * val cmd = new myMain().command( - * info = new CommandInfo( - * name = "foo.main", - * documentation = "Sum all the numbers", - * parameters = Seq( - * new ParameterInfo("first", "scala.Int", hasDefault=false, isVarargs=false, "Fist number to sum"), - * new ParameterInfo("rest", "scala.Int" , hasDefault=false, isVarargs=true, "The rest of the numbers to sum") - * ) + * val mainAnnot = new myMain() + * val info = new Info( + * name = "foo.main", + * documentation = "Sum all the numbers", + * parameters = Seq( + * new Parameter("first", "scala.Int", hasDefault=false, isVarargs=false, "Fist number to sum"), + * new Parameter("rest", "scala.Int" , hasDefault=false, isVarargs=true, "The rest of the numbers to sum") * ) - * args = args * ) - * val args0 = cmd.argGetter[Int](0, None) // using cmd.Parser[Int] - * val args1 = cmd.varargGetter[Int] // using cmd.Parser[Int] - * cmd.run(() => sum(args0(), args1()*)) + * val mainArgsOpt = mainAnnot.command(info, args) + * if mainArgsOpt.isDefined then + * val mainArgs = mainArgsOpt.get + * val args0 = mainAnnot.argGetter[Int](info.parameters(0), mainArgs(0), None) // using parser Int + * val args1 = mainAnnot.argGetter[Int](info.parameters(1), mainArgs(1), Some(() => sum$default$1())) // using parser Int + * val args2 = mainAnnot.varargGetter[Int](info.parameters(2), mainArgs.drop(2)) // using parser Int + * mainAnnot.run(() => sum(args0(), args1(), args2()*)) * } * } * ``` * + * @param Parser The class used for argument string parsing and arguments into a `T` + * @param Result The required result type of the main method. + * If this type is Any or Unit, any type will be accepted. */ @experimental -trait MainAnnotation extends StaticAnnotation: - import MainAnnotation.{Command, CommandInfo} +trait MainAnnotation[Parser[_], Result] extends StaticAnnotation: + import MainAnnotation.{Info, Parameter} - /** A new command with arguments from `args` + /** Process the command arguments before parsing them. + * + * Return `Some` of the sequence of arguments that will be parsed to be passed to the main method. + * This sequence needs to have the same length as the number of parameters of the main method (i.e. `info.parameters.size`). + * If there is a varags parameter, then the sequence must be at least of length `info.parameters.size - 1`. + * + * Returns `None` if the arguments are invalid and parsing and run should be stopped. * * @param info The information about the command (name, documentation and info about parameters) * @param args The command line arguments */ - def command(info: CommandInfo, args: Array[String]): Command[?, ?] + def command(info: Info, args: Seq[String]): Option[Seq[String]] -end MainAnnotation + /** The getter for the `idx`th argument of type `T` + * + * @param idx The index of the argument + * @param defaultArgument Optional lambda to instantiate the default argument + */ + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using Parser[T]): () => T -@experimental -object MainAnnotation: + /** The getter for a final varargs argument of type `T*` */ + def varargGetter[T](param: Parameter, args: Seq[String])(using Parser[T]): () => Seq[T] - /** A class representing a command to run + /** Run `program` if all arguments are valid if all arguments are valid * - * @param Parser The class used for argument string parsing and arguments into a `T` - * @param Result The required result type of the main method. - * If this type is Any or Unit, any type will be accepted. + * @param program A function containing the call to the main method and instantiation of its arguments */ - trait Command[Parser[_], Result]: + def run(program: () => Result): Unit - /** The getter for the `idx`th argument of type `T` - * - * @param idx The index of the argument - * @param defaultArgument Optional lambda to instantiate the default argument - */ - def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using Parser[T]): () => T - - /** The getter for a final varargs argument of type `T*` */ - def varargGetter[T](using Parser[T]): () => Seq[T] +end MainAnnotation - /** Run `program` if all arguments are valid if all arguments are valid - * - * @param program A function containing the call to the main method and instantiation of its arguments - */ - def run(program: () => Result): Unit - end Command +@experimental +object MainAnnotation: /** Information about the main method * * @param name The name of the main method - * @param documentation The documentation of the main method without the `@param` documentation (see ParameterInfo.documentaion) + * @param documentation The documentation of the main method without the `@param` documentation (see Parameter.documentaion) * @param parameters Information about the parameters of the main method */ - final class CommandInfo( + final class Info( val name: String, val documentation: String, - val parameters: Seq[ParameterInfo], - ) + val parameters: Seq[Parameter], + ): + + /** If the method ends with a varargs parameter */ + def hasVarargs: Boolean = parameters.nonEmpty && parameters.last.isVarargs + + end Info /** Information about a parameter of a main method * @@ -104,7 +111,7 @@ object MainAnnotation: * @param documentation The documentation of the parameter (from `@param` documentation in the main method) * @param annotations The annotations of the parameter that extend `ParameterAnnotation` */ - final class ParameterInfo ( + final class Parameter( val name: String, val typeName: String, val hasDefault: Boolean, @@ -113,7 +120,7 @@ object MainAnnotation: val annotations: Seq[ParameterAnnotation], ) - /** Marker trait for annotations that will be included in the ParameterInfo annotations. */ + /** Marker trait for annotations that will be included in the Parameter annotations. */ trait ParameterAnnotation extends StaticAnnotation end MainAnnotation diff --git a/project/scripts/expected-links/reference-expected-links.txt b/project/scripts/expected-links/reference-expected-links.txt index 737267576c6e..f51727b7b432 100644 --- a/project/scripts/expected-links/reference-expected-links.txt +++ b/project/scripts/expected-links/reference-expected-links.txt @@ -68,6 +68,7 @@ ./experimental/erased-defs.html ./experimental/explicit-nulls.html ./experimental/index.html +./experimental/main-annotation.html ./experimental/named-typeargs-spec.html ./experimental/named-typeargs.html ./experimental/numeric-literals.html diff --git a/tests/run/main-annotation-example.scala b/tests/run/main-annotation-example.scala index 91036df44f57..954278d6b26f 100644 --- a/tests/run/main-annotation-example.scala +++ b/tests/run/main-annotation-example.scala @@ -21,39 +21,42 @@ object Test: end Test @experimental -class myMain extends MainAnnotation: - import MainAnnotation.{ Command, CommandInfo, ParameterInfo } +class myMain extends MainAnnotation[FromString, Int]: + import MainAnnotation.{ Info, Parameter } - /** A new command with arguments from `args` */ - def command(info: CommandInfo, args: Array[String]): Command[FromString, Int] = + def command(info: Info, args: Seq[String]): Option[Seq[String]] = if args.contains("--help") then println(info.documentation) - System.exit(0) - assert(info.parameters.forall(!_.hasDefault), "Default arguments are not supported") - val (plainArgs, varargs) = - if info.parameters.last.isVarargs then - val numPlainArgs = info.parameters.length - 1 - assert(numPlainArgs <= args.length, "Not enough arguments") - (args.take(numPlainArgs), args.drop(numPlainArgs)) + None // do not parse or run the program + else if info.parameters.exists(_.hasDefault) then + println("Default arguments are not supported") + None + else if info.hasVarargs then + val numPlainArgs = info.parameters.length - 1 + if numPlainArgs > args.length then + println("Not enough arguments") + None else - assert(info.parameters.length <= args.length, "Not enough arguments") - assert(info.parameters.length >= args.length, "Too many arguments") - (args, Array.empty[String]) - new MyCommand(plainArgs, varargs) - - @experimental - class MyCommand(plainArgs: Seq[String], varargs: Seq[String]) extends Command[FromString, Int]: - - def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using parser: FromString[T]): () => T = - () => parser.fromString(plainArgs(idx)) - - def varargGetter[T](using parser: FromString[T]): () => Seq[T] = - () => varargs.map(arg => parser.fromString(arg)) - - def run(program: () => Int): Unit = - println("executing program") - val result = program() - println("result: " + result) - println("executed program") - end MyCommand + Some(args) + else + if info.parameters.length > args.length then + println("Not enough arguments") + None + else if info.parameters.length < args.length then + println("Too many arguments") + None + else + Some(args) + + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using parser: FromString[T]): () => T = + () => parser.fromString(arg) + + def varargGetter[T](param: Parameter, args: Seq[String])(using parser: FromString[T]): () => Seq[T] = + () => args.map(arg => parser.fromString(arg)) + + def run(program: () => Int): Unit = + println("executing program") + val result = program() + println("result: " + result) + println("executed program") end myMain diff --git a/tests/run/main-annotation-homemade-annot-1.scala b/tests/run/main-annotation-homemade-annot-1.scala index fabbc6348221..daf27b944d99 100644 --- a/tests/run/main-annotation-homemade-annot-1.scala +++ b/tests/run/main-annotation-homemade-annot-1.scala @@ -29,17 +29,18 @@ object Test: end Test @experimental -class mainAwait(timeout: Int = 2) extends MainAnnotation: +class mainAwait(timeout: Int = 2) extends MainAnnotation[FromString, Future[Any]]: import MainAnnotation.* // This is a toy example, it only works with positional args - def command(info: CommandInfo, args: Array[String]): Command[FromString, Future[Any]] = - new Command[FromString, Future[Any]]: - override def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = - () => p.fromString(args(idx)) + def command(info: Info, args: Seq[String]): Option[Seq[String]] = Some(args) - override def varargGetter[T](using p: FromString[T]): () => Seq[T] = - () => for i <- ((info.parameters.length-1) until args.length) yield p.fromString(args(i)) + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = + () => p.fromString(arg) + + def varargGetter[T](param: Parameter, args: Seq[String])(using p: FromString[T]): () => Seq[T] = + () => for arg <- args yield p.fromString(arg) + + def run(f: () => Future[Any]): Unit = println(Await.result(f(), Duration(timeout, SECONDS))) - override def run(f: () => Future[Any]): Unit = println(Await.result(f(), Duration(timeout, SECONDS))) end mainAwait diff --git a/tests/run/main-annotation-homemade-annot-2.scala b/tests/run/main-annotation-homemade-annot-2.scala index e2eecfbd6fcc..3cee9151282d 100644 --- a/tests/run/main-annotation-homemade-annot-2.scala +++ b/tests/run/main-annotation-homemade-annot-2.scala @@ -29,22 +29,21 @@ end Test // This is a toy example, it only works with positional args @experimental -class myMain(runs: Int = 3)(after: String*) extends MainAnnotation: +class myMain(runs: Int = 3)(after: String*) extends MainAnnotation[FromString, Any]: import MainAnnotation.* - def command(info: CommandInfo, args: Array[String]): Command[FromString, Any] = - new Command[FromString, Any]: + def command(info: Info, args: Seq[String]): Option[Seq[String]] = Some(args) - override def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = - () => p.fromString(args(idx)) + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = + () => p.fromString(arg) - override def varargGetter[T](using p: FromString[T]): () => Seq[T] = - () => for i <- (info.parameters.length until args.length) yield p.fromString(args(i)) + def varargGetter[T](param: Parameter, args: Seq[String])(using p: FromString[T]): () => Seq[T] = + () => for arg <- args yield p.fromString(arg) + + def run(f: () => Any): Unit = + for (_ <- 1 to runs) + f() + if after.length > 0 then println(after.mkString(", ")) + end run - override def run(f: () => Any): Unit = - for (_ <- 1 to runs) - f() - if after.length > 0 then println(after.mkString(", ")) - end run - end command end myMain diff --git a/tests/run/main-annotation-homemade-annot-3.scala b/tests/run/main-annotation-homemade-annot-3.scala index 640f6a934004..3fc42abcce79 100644 --- a/tests/run/main-annotation-homemade-annot-3.scala +++ b/tests/run/main-annotation-homemade-annot-3.scala @@ -11,14 +11,13 @@ object Test: end Test @experimental -class mainNoArgs extends MainAnnotation: +class mainNoArgs extends MainAnnotation[FromString, Any]: import MainAnnotation.* - def command(info: CommandInfo, args: Array[String]): Command[FromString, Any] = - new Command[FromString, Any]: - override def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = ??? + def command(info: Info, args: Seq[String]): Option[Seq[String]] = Some(args) - override def varargGetter[T](using p: FromString[T]): () => Seq[T] = ??? + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = ??? - override def run(program: () => Any): Unit = program() - end command + def varargGetter[T](param: Parameter, args: Seq[String])(using p: FromString[T]): () => Seq[T] = ??? + + def run(program: () => Any): Unit = program() diff --git a/tests/run/main-annotation-homemade-annot-4.scala b/tests/run/main-annotation-homemade-annot-4.scala index 602744398e74..0dbd006ee5b1 100644 --- a/tests/run/main-annotation-homemade-annot-4.scala +++ b/tests/run/main-annotation-homemade-annot-4.scala @@ -11,14 +11,14 @@ object Test: end Test @experimental -class mainManyArgs(i1: Int, s2: String, i3: Int) extends MainAnnotation: +class mainManyArgs(i1: Int, s2: String, i3: Int) extends MainAnnotation[FromString, Any]: import MainAnnotation.* - def command(info: CommandInfo, args: Array[String]): Command[FromString, Any] = - new Command[FromString, Any]: - override def argGetter[T](idx: Int, optDefaultGetter: Option[() => T])(using p: FromString[T]): () => T = ??? + def command(info: Info, args: Seq[String]): Option[Seq[String]] = Some(args) - override def varargGetter[T](using p: FromString[T]): () => Seq[T] = ??? + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = ??? - override def run(program: () => Any): Unit = program() - end command + def varargGetter[T](param: Parameter, args: Seq[String])(using p: FromString[T]): () => Seq[T] = ??? + + + def run(program: () => Any): Unit = program() diff --git a/tests/run/main-annotation-homemade-annot-5.scala b/tests/run/main-annotation-homemade-annot-5.scala index e529ac304efe..d61cd55eb852 100644 --- a/tests/run/main-annotation-homemade-annot-5.scala +++ b/tests/run/main-annotation-homemade-annot-5.scala @@ -13,14 +13,13 @@ object Test: end Test @experimental -class mainManyArgs(o: Option[Int]) extends MainAnnotation: +class mainManyArgs(o: Option[Int]) extends MainAnnotation[FromString, Any]: import MainAnnotation.* - def command(info: CommandInfo, args: Array[String]): Command[FromString, Any] = - new Command[FromString, Any]: - override def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = ??? + def command(info: Info, args: Seq[String]): Option[Seq[String]] = Some(args) - override def varargGetter[T](using p: FromString[T]): () => Seq[T] = ??? + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = ??? - override def run(program: () => Any): Unit = program() - end command + def varargGetter[T](param: Parameter, args: Seq[String])(using p: FromString[T]): () => Seq[T] = ??? + + def run(program: () => Any): Unit = program() diff --git a/tests/run/main-annotation-homemade-annot-6.check b/tests/run/main-annotation-homemade-annot-6.check index b9e33bf3e406..5cc6c07e1f56 100644 --- a/tests/run/main-annotation-homemade-annot-6.check +++ b/tests/run/main-annotation-homemade-annot-6.check @@ -1,27 +1,24 @@ command( - Array(), + Array(1, 2), foo, "Foo docs", Seq( - ParameterInfo(name="i", typeName="scala.Int", hasDefault=false, isVarargs=false, documentation="", annotations=List()), - ParameterInfo(name="j", typeName="java.lang.String", hasDefault=true, isVarargs=false, documentation="", annotations=List()) + Parameter(name="i", typeName="scala.Int", hasDefault=false, isVarargs=false, documentation="", annotations=List()), + Parameter(name="j", typeName="java.lang.String", hasDefault=true, isVarargs=false, documentation="", annotations=List()) )* ) -argGetter(0, None) -argGetter(1, Some(2)) run() foo(42, abc) command( - Array(), + Array(1, 2), bar, "Bar docs", Seq( - ParameterInfo(name="i", typeName="scala.collection.immutable.List[Int]", hasDefault=false, isVarargs=false, documentation="the first parameter", annotations=List(MyParamAnnot(3))), - ParameterInfo(name="rest", typeName="scala.Int", hasDefault=false, isVarargs=true, documentation="", annotations=List()) + Parameter(name="i", typeName="scala.collection.immutable.List[Int]", hasDefault=false, isVarargs=false, documentation="the first parameter", annotations=List(MyParamAnnot(3))), + Parameter(name="rest", typeName="scala.Int", hasDefault=false, isVarargs=true, documentation="", annotations=List()) )* ) -argGetter(0, None) varargGetter() run() bar(List(42), 42, 42) diff --git a/tests/run/main-annotation-homemade-annot-6.scala b/tests/run/main-annotation-homemade-annot-6.scala index 5d1c227d0c72..9ba0b31fc689 100644 --- a/tests/run/main-annotation-homemade-annot-6.scala +++ b/tests/run/main-annotation-homemade-annot-6.scala @@ -13,17 +13,17 @@ object Test: for (methodName <- List("foo", "bar")) val clazz = Class.forName(methodName) val method = clazz.getMethod("main", classOf[Array[String]]) - method.invoke(null, Array[String]()) + method.invoke(null, Array[String]("1", "2")) end Test @experimental -class myMain extends MainAnnotation: +class myMain extends MainAnnotation[Make, Any]: import MainAnnotation.* - def command(info: CommandInfo, args: Array[String]): Command[Make, Any] = - def paramInfoString(paramInfo: ParameterInfo) = + def command(info: Info, args: Seq[String]): Option[Seq[String]] = + def paramInfoString(paramInfo: Parameter) = import paramInfo.* - s" ParameterInfo(name=\"$name\", typeName=\"$typeName\", hasDefault=$hasDefault, isVarargs=$isVarargs, documentation=\"$documentation\", annotations=$annotations)" + s" Parameter(name=\"$name\", typeName=\"$typeName\", hasDefault=$hasDefault, isVarargs=$isVarargs, documentation=\"$documentation\", annotations=$annotations)" println( s"""command( | ${args.mkString("Array(", ", ", ")")}, @@ -31,20 +31,19 @@ class myMain extends MainAnnotation: | "${info.documentation}", | ${info.parameters.map(paramInfoString).mkString("Seq(\n", ",\n", "\n )*")} |)""".stripMargin) - new Command[Make, Any]: - override def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using p: Make[T]): () => T = - println(s"argGetter($idx, ${defaultArgument.map(_())})") - () => p.make - - override def varargGetter[T](using p: Make[T]): () => Seq[T] = - println("varargGetter()") - () => Seq(p.make, p.make) - - override def run(f: () => Any): Unit = - println("run()") - f() - println() - end command + Some(args) + + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using p: Make[T]): () => T = + () => p.make + + def varargGetter[T](param: Parameter, args: Seq[String])(using p: Make[T]): () => Seq[T] = + println("varargGetter()") + () => Seq(p.make, p.make) + + def run(f: () => Any): Unit = + println("run()") + f() + println() @experimental case class MyParamAnnot(n: Int) extends MainAnnotation.ParameterAnnotation diff --git a/tests/run/main-annotation-newMain.scala b/tests/run/main-annotation-newMain.scala index c2a538443b24..9e85d5f948cc 100644 --- a/tests/run/main-annotation-newMain.scala +++ b/tests/run/main-annotation-newMain.scala @@ -30,275 +30,288 @@ end Test @experimental -final class newMain extends MainAnnotation: +final class newMain extends MainAnnotation[FromString, Any]: import newMain._ import MainAnnotation._ - def command(info: CommandInfo, args: Array[String]): Command[FromString, Any] = - new Command[FromString, Any]: - - private inline val argMarker = "--" - private inline val shortArgMarker = "-" - - /** - * The name of the special argument to display the method's help. - * If one of the method's parameters is called the same, will be ignored. - */ - private inline val helpArg = "help" - - /** - * The short name of the special argument to display the method's help. - * If one of the method's parameters uses the same short name, will be ignored. - */ - private inline val shortHelpArg = 'h' - private var shortHelpIsOverridden = false - - private inline val maxUsageLineLength = 120 - - /** A map from argument canonical name (the name of the parameter in the method definition) to parameter informations */ - private val nameToParameterInfo: Map[String, ParameterInfo] = info.parameters.map(infos => infos.name -> infos).toMap - - private val (positionalArgs, byNameArgs, invalidByNameArgs, helpIsOverridden) = { - val namesToCanonicalName: Map[String, String] = info.parameters.flatMap( - infos => - var names = getAlternativeNames(infos) - val canonicalName = infos.name - if nameIsValid(canonicalName) then names = canonicalName +: names - names.map(_ -> canonicalName) - ).toMap - val shortNamesToCanonicalName: Map[Char, String] = info.parameters.flatMap( - infos => - var names = getShortNames(infos) - val canonicalName = infos.name - if shortNameIsValid(canonicalName) then names = canonicalName(0) +: names - names.map(_ -> canonicalName) - ).toMap - - val helpIsOverridden = namesToCanonicalName.exists((name, _) => name == helpArg) - shortHelpIsOverridden = shortNamesToCanonicalName.exists((name, _) => name == shortHelpArg) - - def getCanonicalArgName(arg: String): Option[String] = - if arg.startsWith(argMarker) && arg.length > argMarker.length then - namesToCanonicalName.get(arg.drop(argMarker.length)) - else if arg.startsWith(shortArgMarker) && arg.length == shortArgMarker.length + 1 then - shortNamesToCanonicalName.get(arg(shortArgMarker.length)) - else - None - - def isArgName(arg: String): Boolean = - val isFullName = arg.startsWith(argMarker) - val isShortName = arg.startsWith(shortArgMarker) && arg.length == shortArgMarker.length + 1 && shortNameIsValid(arg(shortArgMarker.length)) - isFullName || isShortName - - def recurse(remainingArgs: Seq[String], pa: mutable.Queue[String], bna: Seq[(String, String)], ia: Seq[String]): (mutable.Queue[String], Seq[(String, String)], Seq[String]) = - remainingArgs match { - case Seq() => - (pa, bna, ia) - case argName +: argValue +: rest if isArgName(argName) => - getCanonicalArgName(argName) match { - case Some(canonicalName) => recurse(rest, pa, bna :+ (canonicalName -> argValue), ia) - case None => recurse(rest, pa, bna, ia :+ argName) - } - case arg +: rest => - recurse(rest, pa :+ arg, bna, ia) - } - - val (pa, bna, ia) = recurse(args.toSeq, mutable.Queue.empty, Vector(), Vector()) - val nameToArgValues: Map[String, Seq[String]] = if bna.isEmpty then Map.empty else bna.groupMapReduce(_._1)(p => List(p._2))(_ ++ _) - (pa, nameToArgValues, ia, helpIsOverridden) - } + private inline val argMarker = "--" + private inline val shortArgMarker = "-" - /** A buffer for all errors */ - private val errors = new mutable.ArrayBuffer[String] + /** The name of the special argument to display the method's help. + * If one of the method's parameters is called the same, will be ignored. + */ + private inline val helpArg = "help" - /** Issue an error, and return an uncallable getter */ - private def error(msg: String): () => Nothing = - errors += msg - () => throw new AssertionError("trying to get invalid argument") + /** The short name of the special argument to display the method's help. + * If one of the method's parameters uses the same short name, will be ignored. + */ + private inline val shortHelpArg = 'h' - private inline def nameIsValid(name: String): Boolean = - name.length > 1 // TODO add more checks for illegal characters + private inline val maxUsageLineLength = 120 - private inline def shortNameIsValid(name: String): Boolean = - name.length == 1 && shortNameIsValid(name(0)) + private var info: Info = _ // TODO remove this var - private inline def shortNameIsValid(shortName: Char): Boolean = - ('A' <= shortName && shortName <= 'Z') || ('a' <= shortName && shortName <= 'z') - private def getNameWithMarker(name: String | Char): String = name match { - case c: Char => shortArgMarker + c - case s: String if shortNameIsValid(s) => shortArgMarker + s - case s => argMarker + s - } + /** A buffer for all errors */ + private val errors = new mutable.ArrayBuffer[String] - private def convert[T](argName: String, arg: String)(using p: FromString[T]): () => T = - p.fromStringOption(arg) match - case Some(t) => () => t - case None => error(s"invalid argument for $argName: $arg") - - private def usage(): Unit = - def argsUsage: Seq[String] = - for info <- info.parameters yield - val canonicalName = getNameWithMarker(info.name) - val shortNames = getShortNames(info).map(getNameWithMarker) - val alternativeNames = getAlternativeNames(info).map(getNameWithMarker) - val namesPrint = (canonicalName +: alternativeNames ++: shortNames).mkString("[", " | ", "]") - if info.isVarargs then s"[<${info.typeName}> [<${info.typeName}> [...]]]" - else if info.hasDefault then s"[$namesPrint <${info.typeName}>]" - else s"$namesPrint <${info.typeName}>" - end for - - def wrapArgumentUsages(argsUsage: Seq[String], maxLength: Int): Seq[String] = { - def recurse(args: Seq[String], currentLine: String, acc: Vector[String]): Seq[String] = - (args, currentLine) match { - case (Nil, "") => acc - case (Nil, l) => (acc :+ l) - case (arg +: t, "") => recurse(t, arg, acc) - case (arg +: t, l) if l.length + 1 + arg.length <= maxLength => recurse(t, s"$l $arg", acc) - case (arg +: t, l) => recurse(t, arg, acc :+ l) - } + /** Issue an error, and return an uncallable getter */ + private def error(msg: String): () => Nothing = + errors += msg + () => throw new AssertionError("trying to get invalid argument") - recurse(argsUsage, "", Vector()).toList - } + private def getAliases(param: Parameter): Seq[String] = + param.annotations.collect{ case a: Alias => a }.flatMap(_.aliases) - val usageBeginning = s"Usage: ${info.name} " - val argsOffset = usageBeginning.length - val usages = wrapArgumentUsages(argsUsage, maxUsageLineLength - argsOffset) + private def getAlternativeNames(param: Parameter): Seq[String] = + getAliases(param).filter(nameIsValid(_)) - println(usageBeginning + usages.mkString("\n" + " " * argsOffset)) - end usage + private def getShortNames(param: Parameter): Seq[Char] = + getAliases(param).filter(shortNameIsValid(_)).map(_(0)) - private def explain(): Unit = - inline def shiftLines(s: Seq[String], shift: Int): String = s.map(" " * shift + _).mkString("\n") + private inline def nameIsValid(name: String): Boolean = + name.length > 1 // TODO add more checks for illegal characters - def wrapLongLine(line: String, maxLength: Int): List[String] = { - def recurse(s: String, acc: Vector[String]): Seq[String] = - val lastSpace = s.trim.nn.lastIndexOf(' ', maxLength) - if ((s.length <= maxLength) || (lastSpace < 0)) - acc :+ s - else { - val (shortLine, rest) = s.splitAt(lastSpace) - recurse(rest.trim.nn, acc :+ shortLine) - } + private inline def shortNameIsValid(name: String): Boolean = + name.length == 1 && shortNameIsValidChar(name(0)) - recurse(line, Vector()).toList - } + private inline def shortNameIsValidChar(shortName: Char): Boolean = + ('A' <= shortName && shortName <= 'Z') || ('a' <= shortName && shortName <= 'z') - if (info.documentation.nonEmpty) - println(wrapLongLine(info.documentation, maxUsageLineLength).mkString("\n")) - if (nameToParameterInfo.nonEmpty) { - val argNameShift = 2 - val argDocShift = argNameShift + 2 - - println("Arguments:") - for info <- info.parameters do - val canonicalName = getNameWithMarker(info.name) - val shortNames = getShortNames(info).map(getNameWithMarker) - val alternativeNames = getAlternativeNames(info).map(getNameWithMarker) - val otherNames = (alternativeNames ++: shortNames) match { - case Seq() => "" - case names => names.mkString("(", ", ", ") ") - } - val argDoc = StringBuilder(" " * argNameShift) - argDoc.append(s"$canonicalName $otherNames- ${info.typeName}") - - if info.isVarargs then argDoc.append(" (vararg)") - else if info.hasDefault then argDoc.append(" (optional)") - - val doc = info.documentation - if (doc.nonEmpty) { - val shiftedDoc = - doc.split("\n").nn - .map(line => shiftLines(wrapLongLine(line.nn, maxUsageLineLength - argDocShift), argDocShift)) - .mkString("\n") - argDoc.append("\n").append(shiftedDoc) - } + private def getNameWithMarker(name: String | Char): String = name match { + case c: Char => shortArgMarker + c + case s: String if shortNameIsValid(s) => shortArgMarker + s + case s => argMarker + s + } - println(argDoc) - end for - } - end explain + private def getInvalidNames(param: Parameter): Seq[String | Char] = + getAliases(param).filter(name => !nameIsValid(name) && !shortNameIsValid(name)) - private def getAliases(paramInfos: ParameterInfo): Seq[String] = - paramInfos.annotations.collect{ case a: Alias => a }.flatMap(_.aliases) + def command(info: Info, args: Seq[String]): Option[Seq[String]] = + this.info = info - private def getAlternativeNames(paramInfos: ParameterInfo): Seq[String] = - getAliases(paramInfos).filter(nameIsValid(_)) + val namesToCanonicalName: Map[String, String] = info.parameters.flatMap( + infos => + val names = getAlternativeNames(infos) + val canonicalName = infos.name + if nameIsValid(canonicalName) then (canonicalName +: names).map(_ -> canonicalName) + else names.map(_ -> canonicalName) + ).toMap + val shortNamesToCanonicalName: Map[Char, String] = info.parameters.flatMap( + infos => + val names = getShortNames(infos) + val canonicalName = infos.name + if shortNameIsValid(canonicalName) then (canonicalName(0) +: names).map(_ -> canonicalName) + else names.map(_ -> canonicalName) + ).toMap - private def getShortNames(paramInfos: ParameterInfo): Seq[Char] = - getAliases(paramInfos).filter(shortNameIsValid(_)).map(_(0)) + val helpIsOverridden = namesToCanonicalName.exists((name, _) => name == helpArg) + val shortHelpIsOverridden = shortNamesToCanonicalName.exists((name, _) => name == shortHelpArg) - private def getInvalidNames(paramInfos: ParameterInfo): Seq[String | Char] = - getAliases(paramInfos).filter(name => !nameIsValid(name) && !shortNameIsValid(name)) + val (positionalArgs, byNameArgs, invalidByNameArgs) = { + def getCanonicalArgName(arg: String): Option[String] = + if arg.startsWith(argMarker) && arg.length > argMarker.length then + namesToCanonicalName.get(arg.drop(argMarker.length)) + else if arg.startsWith(shortArgMarker) && arg.length == shortArgMarker.length + 1 then + shortNamesToCanonicalName.get(arg(shortArgMarker.length)) + else + None + + def isArgName(arg: String): Boolean = + val isFullName = arg.startsWith(argMarker) + val isShortName = arg.startsWith(shortArgMarker) && arg.length == shortArgMarker.length + 1 && shortNameIsValidChar(arg(shortArgMarker.length)) + isFullName || isShortName + + def recurse(remainingArgs: Seq[String], pa: mutable.Queue[String], bna: Seq[(String, String)], ia: Seq[String]): (mutable.Queue[String], Seq[(String, String)], Seq[String]) = + remainingArgs match { + case Seq() => + (pa, bna, ia) + case argName +: argValue +: rest if isArgName(argName) => + getCanonicalArgName(argName) match { + case Some(canonicalName) => recurse(rest, pa, bna :+ (canonicalName -> argValue), ia) + case None => recurse(rest, pa, bna, ia :+ argName) + } + case arg +: rest => + recurse(rest, pa :+ arg, bna, ia) + } - override def argGetter[T](idx: Int, optDefaultGetter: Option[() => T])(using p: FromString[T]): () => T = - val name = info.parameters(idx).name - val parameterInfo = nameToParameterInfo(name) - // TODO: Decide which string is associated with this arg when constructing the command. - // Here we should only get the string for this argument, apply it to the parser and handle parsing errors. - // Should be able to get the argument from its index. - byNameArgs.get(name) match { + val (pa, bna, ia) = recurse(args.toSeq, mutable.Queue.empty, Vector(), Vector()) + val nameToArgValues: Map[String, Seq[String]] = if bna.isEmpty then Map.empty else bna.groupMapReduce(_._1)(p => List(p._2))(_ ++ _) + (pa, nameToArgValues, ia) + } + + val argStrings: Seq[Seq[String]] = + for paramInfo <- info.parameters yield { + if (paramInfo.isVarargs) { + val byNameGetters = byNameArgs.getOrElse(paramInfo.name, Seq()) + val positionalGetters = positionalArgs.removeAll() + // First take arguments passed by name, then those passed by position + byNameGetters ++ positionalGetters + } else { + byNameArgs.get(paramInfo.name) match case Some(Nil) => - throw AssertionError(s"$name present in byNameArgs, but it has no argument value") + throw AssertionError(s"${paramInfo.name} present in byNameArgs, but it has no argument value") case Some(argValues) => if argValues.length > 1 then // Do not accept multiple values // Remove this test to take last given argument - error(s"more than one value for $name: ${argValues.mkString(", ")}") + error(s"more than one value for ${paramInfo.name}: ${argValues.mkString(", ")}") + Nil else - convert(name, argValues.last) + List(argValues.last) case None => if positionalArgs.length > 0 then - convert(name, positionalArgs.dequeue) - else if optDefaultGetter.nonEmpty then - optDefaultGetter.get + List(positionalArgs.dequeue()) + else if paramInfo.hasDefault then + Nil else - error(s"missing argument for $name") + error(s"missing argument for ${paramInfo.name}") + Nil } - end argGetter - - override def varargGetter[T](using p: FromString[T]): () => Seq[T] = - val name = info.parameters.last.name - // TODO: Decide which strings are associated with the varargs when constructing the command. - // Here we should only get the strings for this argument, apply them to the parser and handle parsing errors. - // Should be able to get the argument from its index (last). - val byNameGetters = byNameArgs.getOrElse(name, Seq()).map(arg => convert(name, arg)) - val positionalGetters = positionalArgs.removeAll.map(arg => convert(name, arg)) - // First take arguments passed by name, then those passed by position - () => (byNameGetters ++ positionalGetters).map(_()) - - override def run(f: () => Any): Unit = - // Check aliases unicity - val nameAndCanonicalName = nameToParameterInfo.toList.flatMap { - case (canonicalName, infos) => (canonicalName +: getAlternativeNames(infos) ++: getShortNames(infos)).map(_ -> canonicalName) + } + + // Check aliases unicity + val nameAndCanonicalName = info.parameters.flatMap { + case paramInfo => (paramInfo.name +: getAlternativeNames(paramInfo) ++: getShortNames(paramInfo)).map(_ -> paramInfo.name) + } + val nameToCanonicalNames = nameAndCanonicalName.groupMap(_._1)(_._2) + + for (name, canonicalNames) <- nameToCanonicalNames if canonicalNames.length > 1 do + throw IllegalArgumentException(s"$name is used for multiple parameters: ${canonicalNames.mkString(", ")}") + + // Check aliases validity + val problematicNames = info.parameters.flatMap(getInvalidNames) + if problematicNames.length > 0 then + throw IllegalArgumentException(s"The following aliases are invalid: ${problematicNames.mkString(", ")}") + + // Handle unused and invalid args + for (remainingArg <- positionalArgs) error(s"unused argument: $remainingArg") + for (invalidArg <- invalidByNameArgs) error(s"unknown argument name: $invalidArg") + + val displayHelp = + (!helpIsOverridden && args.contains(getNameWithMarker(helpArg))) || + (!shortHelpIsOverridden && args.contains(getNameWithMarker(shortHelpArg))) + + if displayHelp then + usage() + println() + explain() + None + else if errors.nonEmpty then + for msg <- errors do println(s"Error: $msg") + usage() + None + else + Some(argStrings.flatten) + end command + + private def usage(): Unit = + def argsUsage: Seq[String] = + for (infos <- info.parameters) + yield { + val canonicalName = getNameWithMarker(infos.name) + val shortNames = getShortNames(infos).map(getNameWithMarker) + val alternativeNames = getAlternativeNames(infos).map(getNameWithMarker) + val namesPrint = (canonicalName +: alternativeNames ++: shortNames).mkString("[", " | ", "]") + val shortTypeName = infos.typeName.split('.').last + if infos.isVarargs then s"[<$shortTypeName> [<$shortTypeName> [...]]]" + else if infos.hasDefault then s"[$namesPrint <$shortTypeName>]" + else s"$namesPrint <$shortTypeName>" + } + + def wrapArgumentUsages(argsUsage: Seq[String], maxLength: Int): Seq[String] = { + def recurse(args: Seq[String], currentLine: String, acc: Vector[String]): Seq[String] = + (args, currentLine) match { + case (Nil, "") => acc + case (Nil, l) => (acc :+ l) + case (arg +: t, "") => recurse(t, arg, acc) + case (arg +: t, l) if l.length + 1 + arg.length <= maxLength => recurse(t, s"$l $arg", acc) + case (arg +: t, l) => recurse(t, arg, acc :+ l) } - val nameToCanonicalNames = nameAndCanonicalName.groupMap(_._1)(_._2) - for (name, canonicalNames) <- nameToCanonicalNames if canonicalNames.length > 1 - do throw IllegalArgumentException(s"$name is used for multiple parameters: ${canonicalNames.mkString(", ")}") + recurse(argsUsage, "", Vector()).toList + } - // Check aliases validity - val problematicNames = nameToParameterInfo.toList.flatMap((_, infos) => getInvalidNames(infos)) - if problematicNames.length > 0 then throw IllegalArgumentException(s"The following aliases are invalid: ${problematicNames.mkString(", ")}") + val usageBeginning = s"Usage: ${info.name} " + val argsOffset = usageBeginning.length + val usages = wrapArgumentUsages(argsUsage, maxUsageLineLength - argsOffset) - // Handle unused and invalid args - for (remainingArg <- positionalArgs) error(s"unused argument: $remainingArg") - for (invalidArg <- invalidByNameArgs) error(s"unknown argument name: $invalidArg") + println(usageBeginning + usages.mkString("\n" + " " * argsOffset)) + end usage - val displayHelp = - (!helpIsOverridden && args.contains(getNameWithMarker(helpArg))) || (!shortHelpIsOverridden && args.contains(getNameWithMarker(shortHelpArg))) + private def explain(): Unit = + inline def shiftLines(s: Seq[String], shift: Int): String = s.map(" " * shift + _).mkString("\n") + + def wrapLongLine(line: String, maxLength: Int): List[String] = { + def recurse(s: String, acc: Vector[String]): Seq[String] = + val lastSpace = s.trim.nn.lastIndexOf(' ', maxLength) + if ((s.length <= maxLength) || (lastSpace < 0)) + acc :+ s + else { + val (shortLine, rest) = s.splitAt(lastSpace) + recurse(rest.trim.nn, acc :+ shortLine) + } + + recurse(line, Vector()).toList + } + + if (info.documentation.nonEmpty) + println(wrapLongLine(info.documentation, maxUsageLineLength).mkString("\n")) + if (info.parameters.nonEmpty) { + val argNameShift = 2 + val argDocShift = argNameShift + 2 + + println("Arguments:") + for infos <- info.parameters do + val canonicalName = getNameWithMarker(infos.name) + val shortNames = getShortNames(infos).map(getNameWithMarker) + val alternativeNames = getAlternativeNames(infos).map(getNameWithMarker) + val otherNames = (alternativeNames ++: shortNames) match { + case Seq() => "" + case names => names.mkString("(", ", ", ") ") + } + val argDoc = StringBuilder(" " * argNameShift) + argDoc.append(s"$canonicalName $otherNames- ${infos.typeName.split('.').last}") + if infos.isVarargs then argDoc.append(" (vararg)") + else if infos.hasDefault then argDoc.append(" (optional)") + + if (infos.documentation.nonEmpty) { + val shiftedDoc = + infos.documentation.split("\n").nn + .map(line => shiftLines(wrapLongLine(line.nn, maxUsageLineLength - argDocShift), argDocShift)) + .mkString("\n") + argDoc.append("\n").append(shiftedDoc) + } + + println(argDoc) + } + end explain + + private def convert[T](argName: String, arg: String, p: FromString[T]): () => T = + p.fromStringOption(arg) match + case Some(t) => () => t + case None => error(s"invalid argument for $argName: $arg") + + def argGetter[T](param: Parameter, arg: String, defaultArgument: Option[() => T])(using p: FromString[T]): () => T = { + if arg.nonEmpty then convert(param.name, arg, p) + else defaultArgument match + case Some(defaultGetter) => defaultGetter + case None => error(s"missing argument for ${param.name}") + } + + def varargGetter[T](param: Parameter, args: Seq[String])(using p: FromString[T]): () => Seq[T] = { + val getters = args.map(arg => convert(param.name, arg, p)) + () => getters.map(_()) + } + + def run(execProgram: () => Any): Unit = { + if errors.nonEmpty then + for msg <- errors do println(s"Error: $msg") + usage() + else + execProgram() + } - if displayHelp then - usage() - println() - explain() - else if errors.nonEmpty then - for msg <- errors do println(s"Error: $msg") - usage() - else - f() - end run - end command end newMain object newMain: