|  | //@ compile-flags: -Zautodiff=Enable -C opt-level=3  -Clto=fat | 
|  | //@ no-prefer-dynamic | 
|  | //@ needs-enzyme | 
|  | // | 
// In Enzyme, we test against a large range of LLVM versions (5+) and see few
// breakages. One reason is that we only match the IR generated by Enzyme after it has
// run through LLVM's O3 pipeline, which removes most of the noise.
|  | // However, our integration test could also be affected by changes in how rustc lowers MIR into | 
|  | // LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should | 
|  | // reduce this test to only match the first lines and the ret instructions. | 
|  | // | 
// The function tested here has 4 inputs and 5 outputs, so computing its full Jacobian takes
// either 4 forward-mode autodiff calls or 5 reverse-mode calls. Since a forward-mode call is
// usually faster than a reverse-mode call, we prefer forward mode here. This file also tests a
// new optimization (batch mode), which lets us call forward-mode autodiff only once and get all
// 4 directional derivatives (i.e. the full 5x4 Jacobian) from that single call.
|  | // | 
// We support 2 different batch modes. `d_square2` has the same interface as scalar forward mode,
// but each shadow argument is `width` times larger (thus 4 * 4 = 16 elements for the shadow of
// `x` and 4 * 5 = 20 elements for the shadow of `y` here).
// `d_square3` instead takes `width` (4) shadow arguments per argument, each the same size as the
// corresponding original function argument.
|  |  | 
|  | #![feature(autodiff)] | 
|  |  | 
|  | use std::autodiff::autodiff_forward; | 
|  |  | 
|  | // CHECK: ; | 
// Primal function under test: a fixed linear map from the first 4 entries of `x` to the
// first 5 entries of `y`. Its Jacobian is therefore the constant 5x4 coefficient matrix
// spelled out in the body. `#[no_mangle]` keeps the symbol name unmangled in the emitted IR.
//
// Each `autodiff_forward` attribute generates one forward-mode derivative:
// - `d_square1`: scalar mode, with one shadow slice for `x` and one for `y`.
// - `d_square2`: batch mode of width 4 with `Dualv`, where each argument gets a single
//   shadow slice holding all 4 tangents (16 elements for `x`, 20 for `y`).
// - `d_square3`: batch mode of width 4 with `Dual`, where each argument gets 4 separate
//   shadow slices, each the size of the primal argument.
#[no_mangle]
#[autodiff_forward(d_square1, Dual, Dual)]
#[autodiff_forward(d_square2, 4, Dualv, Dualv)]
#[autodiff_forward(d_square3, 4, Dual, Dual)]
fn square(x: &[f32], y: &mut [f32]) {
    // The body reads x[0..4] and writes y[0..5]; check both lengths up front.
    assert!(x.len() >= 4);
    assert!(y.len() >= 5);
    // One output per line: y[i] = sum_j J[i][j] * x[j] with constant coefficients J.
    y[0] = 4.3 * x[0] + 1.2 * x[1] + 3.4 * x[2] + 2.1 * x[3];
    y[1] = 2.3 * x[0] + 4.5 * x[1] + 1.7 * x[2] + 6.4 * x[3];
    y[2] = 1.1 * x[0] + 3.3 * x[1] + 2.5 * x[2] + 4.7 * x[3];
    y[3] = 5.2 * x[0] + 1.4 * x[1] + 2.6 * x[2] + 3.8 * x[3];
    y[4] = 1.0 * x[0] + 2.0 * x[1] + 3.0 * x[2] + 4.0 * x[3];
}
|  |  | 
|  | fn main() { | 
|  | let x1 = std::hint::black_box(vec![0.0, 1.0, 2.0, 3.0]); | 
|  |  | 
|  | let dx1 = std::hint::black_box(vec![1.0; 12]); | 
|  |  | 
|  | let z1 = std::hint::black_box(vec![1.0, 0.0, 0.0, 0.0]); | 
|  | let z2 = std::hint::black_box(vec![0.0, 1.0, 0.0, 0.0]); | 
|  | let z3 = std::hint::black_box(vec![0.0, 0.0, 1.0, 0.0]); | 
|  | let z4 = std::hint::black_box(vec![0.0, 0.0, 0.0, 1.0]); | 
|  |  | 
|  | let z5 = std::hint::black_box(vec![ | 
|  | 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, | 
|  | ]); | 
|  |  | 
|  | let mut y1 = std::hint::black_box(vec![0.0; 5]); | 
|  | let mut y2 = std::hint::black_box(vec![0.0; 5]); | 
|  | let mut y3 = std::hint::black_box(vec![0.0; 5]); | 
|  | let mut y4 = std::hint::black_box(vec![0.0; 5]); | 
|  |  | 
|  | let mut y5 = std::hint::black_box(vec![0.0; 5]); | 
|  |  | 
|  | let mut y6 = std::hint::black_box(vec![0.0; 5]); | 
|  |  | 
|  | let mut dy1_1 = std::hint::black_box(vec![0.0; 5]); | 
|  | let mut dy1_2 = std::hint::black_box(vec![0.0; 5]); | 
|  | let mut dy1_3 = std::hint::black_box(vec![0.0; 5]); | 
|  | let mut dy1_4 = std::hint::black_box(vec![0.0; 5]); | 
|  |  | 
|  | let mut dy2 = std::hint::black_box(vec![0.0; 20]); | 
|  |  | 
|  | let mut dy3_1 = std::hint::black_box(vec![0.0; 5]); | 
|  | let mut dy3_2 = std::hint::black_box(vec![0.0; 5]); | 
|  | let mut dy3_3 = std::hint::black_box(vec![0.0; 5]); | 
|  | let mut dy3_4 = std::hint::black_box(vec![0.0; 5]); | 
|  |  | 
|  | // scalar. | 
|  | d_square1(&x1, &z1, &mut y1, &mut dy1_1); | 
|  | d_square1(&x1, &z2, &mut y2, &mut dy1_2); | 
|  | d_square1(&x1, &z3, &mut y3, &mut dy1_3); | 
|  | d_square1(&x1, &z4, &mut y4, &mut dy1_4); | 
|  |  | 
|  | // assert y1 == y2 == y3 == y4 | 
|  | for i in 0..5 { | 
|  | assert_eq!(y1[i], y2[i]); | 
|  | assert_eq!(y1[i], y3[i]); | 
|  | assert_eq!(y1[i], y4[i]); | 
|  | } | 
|  |  | 
|  | // batch mode A) | 
|  | d_square2(&x1, &z5, &mut y5, &mut dy2); | 
|  |  | 
|  | // assert y1 == y2 == y3 == y4 == y5 | 
|  | for i in 0..5 { | 
|  | assert_eq!(y1[i], y5[i]); | 
|  | } | 
|  |  | 
|  | // batch mode B) | 
|  | d_square3(&x1, &z1, &z2, &z3, &z4, &mut y6, &mut dy3_1, &mut dy3_2, &mut dy3_3, &mut dy3_4); | 
|  | for i in 0..5 { | 
|  | assert_eq!(y5[i], y6[i]); | 
|  | } | 
|  |  | 
|  | for i in 0..5 { | 
|  | assert_eq!(dy2[0..5][i], dy3_1[i]); | 
|  | assert_eq!(dy2[5..10][i], dy3_2[i]); | 
|  | assert_eq!(dy2[10..15][i], dy3_3[i]); | 
|  | assert_eq!(dy2[15..20][i], dy3_4[i]); | 
|  | } | 
|  | } |