Skip to content

Split autodiff into autodiff_forward and autodiff_reverse #140697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion compiler/rustc_builtin_macros/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ builtin_macros_assert_requires_expression = macro requires an expression as an a

builtin_macros_autodiff = autodiff must be applied to function
builtin_macros_autodiff_missing_config = autodiff requires at least a name and mode
builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse`
builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}
Expand Down
67 changes: 47 additions & 20 deletions compiler/rustc_builtin_macros/src/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,27 +86,23 @@ mod llvm_enzyme {
ecx: &mut ExtCtxt<'_>,
meta_item: &ThinVec<MetaItemInner>,
has_ret: bool,
mode: DiffMode,
) -> AutoDiffAttrs {
let dcx = ecx.sess.dcx();
let mode = name(&meta_item[1]);
let Ok(mode) = DiffMode::from_str(&mode) else {
dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
return AutoDiffAttrs::error();
};

// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
// If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
let mut first_activity = 2;
let mut first_activity = 1;

let width = if let [_, _, x, ..] = &meta_item[..]
let width = if let [_, x, ..] = &meta_item[..]
&& let Some(x) = width(x)
{
first_activity = 3;
first_activity = 2;
match x.try_into() {
Ok(x) => x,
Err(_) => {
dcx.emit_err(errors::AutoDiffInvalidWidth {
span: meta_item[2].span(),
span: meta_item[1].span(),
width: x,
});
return AutoDiffAttrs::error();
Expand Down Expand Up @@ -165,6 +161,24 @@ mod llvm_enzyme {
ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
}

pub(crate) fn expand_forward(
ecx: &mut ExtCtxt<'_>,
expand_span: Span,
meta_item: &ast::MetaItem,
item: Annotatable,
) -> Vec<Annotatable> {
expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
}

pub(crate) fn expand_reverse(
ecx: &mut ExtCtxt<'_>,
expand_span: Span,
meta_item: &ast::MetaItem,
item: Annotatable,
) -> Vec<Annotatable> {
expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
}

/// We expand the autodiff macro to generate a new placeholder function which passes
/// type-checking and can be called by users. The function body of the placeholder function will
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
Expand Down Expand Up @@ -198,11 +212,12 @@ mod llvm_enzyme {
/// ```
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
/// in CI.
pub(crate) fn expand(
pub(crate) fn expand_with_mode(
ecx: &mut ExtCtxt<'_>,
expand_span: Span,
meta_item: &ast::MetaItem,
mut item: Annotatable,
mode: DiffMode,
) -> Vec<Annotatable> {
if cfg!(not(llvm_enzyme)) {
ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
Expand Down Expand Up @@ -243,29 +258,41 @@ mod llvm_enzyme {
// create TokenStream from vec elemtents:
// meta_item doesn't have a .tokens field
let mut ts: Vec<TokenTree> = vec![];
if meta_item_vec.len() < 2 {
// At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
// input and output args.
if meta_item_vec.len() < 1 {
// At the bare minimum, we need a fnc name.
dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
return vec![item];
}

meta_item_inner_to_ts(&meta_item_vec[1], &mut ts);
let mode_symbol = match mode {
DiffMode::Forward => sym::Forward,
DiffMode::Reverse => sym::Reverse,
_ => unreachable!("Unsupported mode: {:?}", mode),
};

// Insert mode token
let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
ts.insert(
1,
TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
);

// Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
// If it is not given, we default to 1 (scalar mode).
let start_position;
let kind: LitKind = LitKind::Integer;
let symbol;
if meta_item_vec.len() >= 3
&& let Some(width) = width(&meta_item_vec[2])
if meta_item_vec.len() >= 2
&& let Some(width) = width(&meta_item_vec[1])
{
start_position = 3;
start_position = 2;
symbol = Symbol::intern(&width.to_string());
} else {
start_position = 2;
start_position = 1;
symbol = sym::integer(1);
}

let l: Lit = Lit { kind, symbol, suffix: None };
let t = Token::new(TokenKind::Literal(l), Span::default());
let comma = Token::new(TokenKind::Comma, Span::default());
Expand All @@ -287,7 +314,7 @@ mod llvm_enzyme {
ts.pop();
let ts: TokenStream = TokenStream::from_iter(ts);

let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode);
if !x.is_active() {
// We encountered an error, so we return the original item.
// This allows us to potentially parse other attributes.
Expand Down Expand Up @@ -966,4 +993,4 @@ mod llvm_enzyme {
}
}

pub(crate) use llvm_enzyme::expand;
pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};
8 changes: 0 additions & 8 deletions compiler/rustc_builtin_macros/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,6 @@ mod autodiff {
pub(crate) act: String,
}

#[derive(Diagnostic)]
#[diag(builtin_macros_autodiff_mode)]
pub(crate) struct AutoDiffInvalidMode {
#[primary_span]
pub(crate) span: Span,
pub(crate) mode: String,
}

#[derive(Diagnostic)]
#[diag(builtin_macros_autodiff_width)]
pub(crate) struct AutoDiffInvalidWidth {
Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_builtin_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
#![allow(rustc::diagnostic_outside_of_impl)]
#![allow(rustc::untranslatable_diagnostic)]
#![cfg_attr(bootstrap, feature(let_chains))]
#![cfg_attr(not(bootstrap), feature(autodiff))]
#![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]
#![doc(rust_logo)]
#![feature(assert_matches)]
#![feature(autodiff)]
#![feature(box_patterns)]
#![feature(decl_macro)]
#![feature(if_let_guard)]
Expand Down Expand Up @@ -113,7 +113,8 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) {

register_attr! {
alloc_error_handler: alloc_error_handler::expand,
autodiff: autodiff::expand,
autodiff_forward: autodiff::expand_forward,
autodiff_reverse: autodiff::expand_reverse,
bench: test::expand_bench,
cfg_accessible: cfg_accessible::Expander,
cfg_eval: cfg_eval::expand,
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_passes/src/check_attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
self.check_generic_attr(hir_id, attr, target, Target::Fn);
self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
}
[sym::autodiff, ..] => {
[sym::autodiff_forward, ..] | [sym::autodiff_reverse, ..] => {
self.check_autodiff(hir_id, attr, span, target)
}
[sym::coroutine, ..] => {
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ symbols! {
FnMut,
FnOnce,
Formatter,
Forward,
From,
FromIterator,
FromResidual,
Expand Down Expand Up @@ -346,6 +347,7 @@ symbols! {
Result,
ResumeTy,
Return,
Reverse,
Right,
Rust,
RustaceansAreAwesome,
Expand Down Expand Up @@ -528,7 +530,8 @@ symbols! {
audit_that,
augmented_assignments,
auto_traits,
autodiff,
autodiff_forward,
autodiff_reverse,
automatically_derived,
avx,
avx512_target_feature,
Expand Down
3 changes: 2 additions & 1 deletion library/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,11 @@ pub mod assert_matches {

// We don't export this through #[macro_export] for now, to avoid breakage.
#[unstable(feature = "autodiff", issue = "124509")]
#[cfg(not(bootstrap))]
/// Unstable module containing the unstable `autodiff` macro.
pub mod autodiff {
#[unstable(feature = "autodiff", issue = "124509")]
pub use crate::macros::builtin::autodiff;
pub use crate::macros::builtin::{autodiff_forward, autodiff_reverse};
}

#[unstable(feature = "contracts", issue = "128044")]
Expand Down
40 changes: 30 additions & 10 deletions library/core/src/macros/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1519,20 +1519,40 @@ pub(crate) mod builtin {
($file:expr $(,)?) => {{ /* compiler built-in */ }};
}

/// the derivative of a given function in the forward mode of differentiation.
/// It may only be applied to a function.
///
/// The expected usage syntax is:
/// `#[autodiff_forward(NAME, INPUT_ACTIVITIES, OUTPUT_ACTIVITY)]`
///
/// - `NAME`: A string that represents a valid function name.
/// - `INPUT_ACTIVITIES`: Specifies one valid activity for each input parameter.
/// - `OUTPUT_ACTIVITY`: Must not be set if the function implicitly returns nothing
/// (or explicitly returns `-> ()`). Otherwise, it must be set to one of the allowed activities.
#[unstable(feature = "autodiff", issue = "124509")]
#[allow_internal_unstable(rustc_attrs)]
#[rustc_builtin_macro]
#[cfg(not(bootstrap))]
pub macro autodiff_forward($item:item) {
/* compiler built-in */
}

/// Automatic Differentiation macro which allows generating a new function to compute
/// the derivative of a given function. It may only be applied to a function.
/// The expected usage syntax is
/// `#[autodiff(NAME, MODE, INPUT_ACTIVITIES, OUTPUT_ACTIVITY)]`
/// where:
/// NAME is a string that represents a valid function name.
/// MODE is any of Forward, Reverse, ForwardFirst, ReverseFirst.
/// INPUT_ACTIVITIES consists of one valid activity for each input parameter.
/// OUTPUT_ACTIVITY must not be set if we implicitly return nothing (or explicitly return
/// `-> ()`). Otherwise it must be set to one of the allowed activities.
/// the derivative of a given function in the reverse mode of differentiation.
/// It may only be applied to a function.
///
/// The expected usage syntax is:
/// `#[autodiff_reverse(NAME, INPUT_ACTIVITIES, OUTPUT_ACTIVITY)]`
///
/// - `NAME`: A string that represents a valid function name.
/// - `INPUT_ACTIVITIES`: Specifies one valid activity for each input parameter.
/// - `OUTPUT_ACTIVITY`: Must not be set if the function implicitly returns nothing
/// (or explicitly returns `-> ()`). Otherwise, it must be set to one of the allowed activities.
#[unstable(feature = "autodiff", issue = "124509")]
#[allow_internal_unstable(rustc_attrs)]
#[rustc_builtin_macro]
pub macro autodiff($item:item) {
#[cfg(not(bootstrap))]
pub macro autodiff_reverse($item:item) {
/* compiler built-in */
}

Expand Down
7 changes: 5 additions & 2 deletions library/std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,12 @@
// tidy-alphabetical-start

// stabilization was reverted after it hit beta
#![cfg_attr(not(bootstrap), feature(autodiff))]
#![feature(alloc_error_handler)]
#![feature(allocator_internals)]
#![feature(allow_internal_unsafe)]
#![feature(allow_internal_unstable)]
#![feature(asm_experimental_arch)]
#![feature(autodiff)]
#![feature(cfg_sanitizer_cfi)]
#![feature(cfg_target_thread_local)]
#![feature(cfi_encoding)]
Expand Down Expand Up @@ -632,12 +632,15 @@ pub mod simd {
#[doc(inline)]
pub use crate::std_float::StdFloat;
}

#[unstable(feature = "autodiff", issue = "124509")]
#[cfg(not(bootstrap))]
/// This module provides support for automatic differentiation.
pub mod autodiff {
/// This macro handles automatic differentiation.
pub use core::autodiff::autodiff;
pub use core::autodiff::{autodiff_forward, autodiff_reverse};
}

#[stable(feature = "futures_api", since = "1.36.0")]
pub mod task {
//! Types and Traits for working with asynchronous tasks.
Expand Down
8 changes: 4 additions & 4 deletions tests/codegen/autodiff/batched.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@

#![feature(autodiff)]

use std::autodiff::autodiff;
use std::autodiff::autodiff_forward;

#[autodiff(d_square3, Forward, Dual, DualOnly)]
#[autodiff(d_square2, Forward, 4, Dual, DualOnly)]
#[autodiff(d_square1, Forward, 4, Dual, Dual)]
#[autodiff_forward(d_square3, Dual, DualOnly)]
#[autodiff_forward(d_square2, 4, Dual, DualOnly)]
#[autodiff_forward(d_square1, 4, Dual, Dual)]
#[no_mangle]
fn square(x: &f32) -> f32 {
x * x
Expand Down
6 changes: 3 additions & 3 deletions tests/codegen/autodiff/identical_fnc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
// identical function calls in the LLVM-IR, while having two different calls in the Rust code.
#![feature(autodiff)]

use std::autodiff::autodiff;
use std::autodiff::autodiff_reverse;

#[autodiff(d_square, Reverse, Duplicated, Active)]
#[autodiff_reverse(d_square, Duplicated, Active)]
fn square(x: &f64) -> f64 {
x * x
}

#[autodiff(d_square2, Reverse, Duplicated, Active)]
#[autodiff_reverse(d_square2, Duplicated, Active)]
fn square2(x: &f64) -> f64 {
x * x
}
Expand Down
4 changes: 2 additions & 2 deletions tests/codegen/autodiff/inline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

#![feature(autodiff)]

use std::autodiff::autodiff;
use std::autodiff::autodiff_reverse;

#[autodiff(d_square, Reverse, Duplicated, Active)]
#[autodiff_reverse(d_square, Duplicated, Active)]
fn square(x: &f64) -> f64 {
x * x
}
Expand Down
4 changes: 2 additions & 2 deletions tests/codegen/autodiff/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
//@ needs-enzyme
#![feature(autodiff)]

use std::autodiff::autodiff;
use std::autodiff::autodiff_reverse;

#[autodiff(d_square, Reverse, Duplicated, Active)]
#[autodiff_reverse(d_square, Duplicated, Active)]
#[no_mangle]
fn square(x: &f64) -> f64 {
x * x
Expand Down
4 changes: 2 additions & 2 deletions tests/codegen/autodiff/sret.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

#![feature(autodiff)]

use std::autodiff::autodiff;
use std::autodiff::autodiff_reverse;

#[no_mangle]
#[autodiff(df, Reverse, Active, Active, Active)]
#[autodiff_reverse(df, Active, Active, Active)]
fn primal(x: f32, y: f32) -> f64 {
(x * x * y) as f64
}
Expand Down
2 changes: 1 addition & 1 deletion tests/pretty/autodiff/autodiff_forward.pp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

// Test that forward mode ad macros are expanded correctly.

use std::autodiff::autodiff;
use std::autodiff::{autodiff_forward, autodiff_reverse};

#[rustc_autodiff]
#[inline(never)]
Expand Down
Loading
Loading