Skip to content

Commit 34f259b

Browse files
committed
Split autodiff into autodiff_forward and autodiff_reverse
Pending fix. ``` error: cannot find a built-in macro with name `autodiff_forward` --> library\core\src\macros\mod.rs:1542:5 | 1542 | / pub macro autodiff_forward($item:item) { 1543 | | /* compiler built-in */ 1544 | | } | |_____^ error: cannot find a built-in macro with name `autodiff_reverse` --> library\core\src\macros\mod.rs:1549:5 | 1549 | / pub macro autodiff_reverse($item:item) { 1550 | | /* compiler built-in */ 1551 | | } | |_____^ error: could not compile `core` (lib) due to 2 previous errors ```
1 parent 427288b commit 34f259b

File tree

6 files changed

+49
-18
lines changed

6 files changed

+49
-18
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

+29-14
Original file line numberDiff line numberDiff line change
@@ -88,25 +88,20 @@ mod llvm_enzyme {
8888
has_ret: bool,
8989
) -> AutoDiffAttrs {
9090
let dcx = ecx.sess.dcx();
91-
let mode = name(&meta_item[1]);
92-
let Ok(mode) = DiffMode::from_str(&mode) else {
93-
dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
94-
return AutoDiffAttrs::error();
95-
};
9691

9792
// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
9893
// If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
99-
let mut first_activity = 2;
94+
let mut first_activity = 1;
10095

101-
let width = if let [_, _, x, ..] = &meta_item[..]
96+
let width = if let [_, x, ..] = &meta_item[..]
10297
&& let Some(x) = width(x)
10398
{
104-
first_activity = 3;
99+
first_activity = 2;
105100
match x.try_into() {
106101
Ok(x) => x,
107102
Err(_) => {
108103
dcx.emit_err(errors::AutoDiffInvalidWidth {
109-
span: meta_item[2].span(),
104+
span: meta_item[1].span(),
110105
width: x,
111106
});
112107
return AutoDiffAttrs::error();
@@ -150,7 +145,7 @@ mod llvm_enzyme {
150145
};
151146

152147
AutoDiffAttrs {
153-
mode,
148+
mode: DiffMode::Error,
154149
width,
155150
ret_activity: *ret_activity,
156151
input_activity: input_activity.to_vec(),
@@ -165,6 +160,24 @@ mod llvm_enzyme {
165160
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
166161
}
167162

163+
pub(crate) fn expand_forward(
164+
ecx: &mut ExtCtxt<'_>,
165+
expand_span: Span,
166+
meta_item: &ast::MetaItem,
167+
item: Annotatable,
168+
) -> Vec<Annotatable> {
169+
expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
170+
}
171+
172+
pub(crate) fn expand_reverse(
173+
ecx: &mut ExtCtxt<'_>,
174+
expand_span: Span,
175+
meta_item: &ast::MetaItem,
176+
item: Annotatable,
177+
) -> Vec<Annotatable> {
178+
expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
179+
}
180+
168181
/// We expand the autodiff macro to generate a new placeholder function which passes
169182
/// type-checking and can be called by users. The function body of the placeholder function will
170183
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
@@ -198,11 +211,12 @@ mod llvm_enzyme {
198211
/// ```
199212
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
200213
/// in CI.
201-
pub(crate) fn expand(
214+
pub(crate) fn expand_with_mode(
202215
ecx: &mut ExtCtxt<'_>,
203216
expand_span: Span,
204217
meta_item: &ast::MetaItem,
205218
mut item: Annotatable,
219+
mode: DiffMode,
206220
) -> Vec<Annotatable> {
207221
if cfg!(not(llvm_enzyme)) {
208222
ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
@@ -287,7 +301,8 @@ mod llvm_enzyme {
287301
ts.pop();
288302
let ts: TokenStream = TokenStream::from_iter(ts);
289303

290-
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
304+
let mut x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
305+
x.mode = mode;
291306
if !x.is_active() {
292307
// We encountered an error, so we return the original item.
293308
// This allows us to potentially parse other attributes.
@@ -964,6 +979,6 @@ mod llvm_enzyme {
964979
trace!("Generated signature: {:?}", d_sig);
965980
(d_sig, new_inputs, idents, false)
966981
}
967-
}
968982

969-
pub(crate) use llvm_enzyme::expand;
983+
984+
pub(crate) use llvm_enzyme::{expand, expand_forward, expand_reverse};

compiler/rustc_builtin_macros/src/lib.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) {
113113

114114
register_attr! {
115115
alloc_error_handler: alloc_error_handler::expand,
116-
autodiff: autodiff::expand,
116+
autodiff_forward: autodiff::expand_forward,
117+
autodiff_reverse: autodiff::expand_reverse,
117118
bench: test::expand_bench,
118119
cfg_accessible: cfg_accessible::Expander,
119120
cfg_eval: cfg_eval::expand,

compiler/rustc_passes/src/check_attr.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
255255
self.check_generic_attr(hir_id, attr, target, Target::Fn);
256256
self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
257257
}
258-
[sym::autodiff, ..] => {
258+
[sym::autodiff_forward, ..] | [sym::autodiff_reverse, ..] => {
259259
self.check_autodiff(hir_id, attr, span, target)
260260
}
261261
[sym::coroutine, ..] => {

compiler/rustc_span/src/symbol.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,8 @@ symbols! {
528528
audit_that,
529529
augmented_assignments,
530530
auto_traits,
531-
autodiff,
531+
autodiff_forward,
532+
autodiff_reverse,
532533
automatically_derived,
533534
avx,
534535
avx512_target_feature,

library/core/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ pub mod assert_matches {
229229
/// Unstable module containing the unstable `autodiff` macro.
230230
pub mod autodiff {
231231
#[unstable(feature = "autodiff", issue = "124509")]
232-
pub use crate::macros::builtin::autodiff;
232+
pub use crate::macros::builtin::{autodiff_forward, autodiff_reverse};
233233
}
234234

235235
#[unstable(feature = "contracts", issue = "128044")]

library/core/src/macros/mod.rs

+14
Original file line numberDiff line numberDiff line change
@@ -1536,6 +1536,20 @@ pub(crate) mod builtin {
15361536
/* compiler built-in */
15371537
}
15381538

1539+
#[unstable(feature = "autodiff", issue = "124509")]
1540+
#[allow_internal_unstable(rustc_attrs)]
1541+
#[rustc_builtin_macro]
1542+
pub macro autodiff_forward($item:item) {
1543+
/* compiler built-in */
1544+
}
1545+
1546+
#[unstable(feature = "autodiff", issue = "124509")]
1547+
#[allow_internal_unstable(rustc_attrs)]
1548+
#[rustc_builtin_macro]
1549+
pub macro autodiff_reverse($item:item) {
1550+
/* compiler built-in */
1551+
}
1552+
15391553
/// Asserts that a boolean expression is `true` at runtime.
15401554
///
15411555
/// This will invoke the [`panic!`] macro if the provided expression cannot be

0 commit comments

Comments
 (0)