@@ -88,25 +88,20 @@ mod llvm_enzyme {
88
88
has_ret : bool ,
89
89
) -> AutoDiffAttrs {
90
90
let dcx = ecx. sess . dcx ( ) ;
91
- let mode = name ( & meta_item[ 1 ] ) ;
92
- let Ok ( mode) = DiffMode :: from_str ( & mode) else {
93
- dcx. emit_err ( errors:: AutoDiffInvalidMode { span : meta_item[ 1 ] . span ( ) , mode } ) ;
94
- return AutoDiffAttrs :: error ( ) ;
95
- } ;
96
91
97
92
// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
98
93
// If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
99
- let mut first_activity = 2 ;
94
+ let mut first_activity = 1 ;
100
95
101
- let width = if let [ _, _ , x, ..] = & meta_item[ ..]
96
+ let width = if let [ _, x, ..] = & meta_item[ ..]
102
97
&& let Some ( x) = width ( x)
103
98
{
104
- first_activity = 3 ;
99
+ first_activity = 2 ;
105
100
match x. try_into ( ) {
106
101
Ok ( x) => x,
107
102
Err ( _) => {
108
103
dcx. emit_err ( errors:: AutoDiffInvalidWidth {
109
- span : meta_item[ 2 ] . span ( ) ,
104
+ span : meta_item[ 1 ] . span ( ) ,
110
105
width : x,
111
106
} ) ;
112
107
return AutoDiffAttrs :: error ( ) ;
@@ -150,7 +145,7 @@ mod llvm_enzyme {
150
145
} ;
151
146
152
147
AutoDiffAttrs {
153
- mode,
148
+ mode : DiffMode :: Error ,
154
149
width,
155
150
ret_activity : * ret_activity,
156
151
input_activity : input_activity. to_vec ( ) ,
@@ -165,6 +160,24 @@ mod llvm_enzyme {
165
160
ts. push ( TokenTree :: Token ( comma. clone ( ) , Spacing :: Alone ) ) ;
166
161
}
167
162
163
+ pub ( crate ) fn expand_forward (
164
+ ecx : & mut ExtCtxt < ' _ > ,
165
+ expand_span : Span ,
166
+ meta_item : & ast:: MetaItem ,
167
+ item : Annotatable ,
168
+ ) -> Vec < Annotatable > {
169
+ expand_with_mode ( ecx, expand_span, meta_item, item, DiffMode :: Forward )
170
+ }
171
+
172
+ pub ( crate ) fn expand_reverse (
173
+ ecx : & mut ExtCtxt < ' _ > ,
174
+ expand_span : Span ,
175
+ meta_item : & ast:: MetaItem ,
176
+ item : Annotatable ,
177
+ ) -> Vec < Annotatable > {
178
+ expand_with_mode ( ecx, expand_span, meta_item, item, DiffMode :: Reverse )
179
+ }
180
+
168
181
/// We expand the autodiff macro to generate a new placeholder function which passes
169
182
/// type-checking and can be called by users. The function body of the placeholder function will
170
183
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
@@ -198,11 +211,12 @@ mod llvm_enzyme {
198
211
/// ```
199
212
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
200
213
/// in CI.
201
- pub ( crate ) fn expand (
214
+ pub ( crate ) fn expand_with_mode (
202
215
ecx : & mut ExtCtxt < ' _ > ,
203
216
expand_span : Span ,
204
217
meta_item : & ast:: MetaItem ,
205
218
mut item : Annotatable ,
219
+ mode : DiffMode ,
206
220
) -> Vec < Annotatable > {
207
221
if cfg ! ( not( llvm_enzyme) ) {
208
222
ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffSupportNotBuild { span : meta_item. span } ) ;
@@ -287,7 +301,8 @@ mod llvm_enzyme {
287
301
ts. pop ( ) ;
288
302
let ts: TokenStream = TokenStream :: from_iter ( ts) ;
289
303
290
- let x: AutoDiffAttrs = from_ast ( ecx, & meta_item_vec, has_ret) ;
304
+ let mut x: AutoDiffAttrs = from_ast ( ecx, & meta_item_vec, has_ret) ;
305
+ x. mode = mode;
291
306
if !x. is_active ( ) {
292
307
// We encountered an error, so we return the original item.
293
308
// This allows us to potentially parse other attributes.
@@ -964,6 +979,6 @@ mod llvm_enzyme {
964
979
trace ! ( "Generated signature: {:?}" , d_sig) ;
965
980
( d_sig, new_inputs, idents, false )
966
981
}
967
- }
968
982
969
- pub ( crate ) use llvm_enzyme:: expand;
983
+
984
+ pub ( crate ) use llvm_enzyme:: { expand, expand_forward, expand_reverse} ;
0 commit comments