Skip to content

Commit 0218840

Browse files
committed
Initial naive implementation using Symbols to represent autodiff modes (Forward, Reverse)
Since the mode is no longer part of `meta_item`, we must insert it manually (otherwise macro expansion with `#[rustc_autodiff]` won't work). This can be revised later if a more structured representation becomes necessary (using enums, annotated structs, etc). Some tests are currently failing. I'll address them next.
1 parent af9e91c commit 0218840

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

+20-8
Original file line numberDiff line numberDiff line change
@@ -257,29 +257,41 @@ mod llvm_enzyme {
257257
// create TokenStream from vec elemtents:
258258
// meta_item doesn't have a .tokens field
259259
let mut ts: Vec<TokenTree> = vec![];
260-
if meta_item_vec.len() < 2 {
261-
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
262-
// input and output args.
260+
if meta_item_vec.len() < 1 {
261+
// At the bare minimum, we need a fnc name.
263262
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
264263
return vec![item];
265264
}
266265

267-
meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
266+
let mode_symbol = match mode {
267+
DiffMode::Forward => sym::Forward,
268+
DiffMode::Reverse => sym::Reverse,
269+
_ => unreachable!("Unsupported mode: {:?}", mode),
270+
};
271+
272+
// Insert mode token
273+
let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
274+
ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
275+
ts.insert(
276+
1,
277+
TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
278+
);
268279

269280
// Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
270281
// If it is not given, we default to 1 (scalar mode).
271282
let start_position;
272283
let kind: LitKind = LitKind::Integer;
273284
let symbol;
274-
if meta_item_vec.len() >= 3
275-
&& let Some(width) = width(&meta_item_vec[2])
285+
if meta_item_vec.len() >= 2
286+
&& let Some(width) = width(&meta_item_vec[1])
276287
{
277-
start_position = 3;
288+
start_position = 2;
278289
symbol = Symbol::intern(&width.to_string());
279290
} else {
280-
start_position = 2;
291+
start_position = 1;
281292
symbol = sym::integer(1);
282293
}
294+
283295
let l: Lit = Lit { kind, symbol, suffix: None };
284296
let t = Token::new(TokenKind::Literal(l), Span::default());
285297
let comma = Token::new(TokenKind::Comma, Span::default());

compiler/rustc_span/src/symbol.rs

+2
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ symbols! {
253253
FnMut,
254254
FnOnce,
255255
Formatter,
256+
Forward,
256257
From,
257258
FromIterator,
258259
FromResidual,
@@ -346,6 +347,7 @@ symbols! {
346347
Result,
347348
ResumeTy,
348349
Return,
350+
Reverse,
349351
Right,
350352
Rust,
351353
RustaceansAreAwesome,

0 commit comments

Comments
 (0)