// blob: 306a6ed9d1f4ff477dbda320d55bd9aa062d089e [file] [log] [blame]
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
//
// In Enzyme, we test against a wide range of LLVM versions (5+) and see relatively few
// breakages. It helps that we match the IR generated by Enzyme only after running it
// through LLVM's O3 pipeline, which removes most of the noise.
// However, our integration test could also be affected by changes in how rustc lowers MIR into
// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should
// reduce this test to only match the first lines and the ret instructions.
#![feature(autodiff)]
use std::autodiff::autodiff_forward;
// Primal function under test: returns the square of the referenced value.
//
// Each `autodiff_forward` attribute requests one generated forward-mode derivative:
// - `d_square3`: takes one seed; `main` uses its result as a single value.
// - `d_square2`: the `4` appears to be the batch width; `main` passes four seeds and
//   destructures four results.
// - `d_square1`: same width, but `main` destructures five results and compares the first one
//   with the output of `square` itself.
//
// NOTE(review): `no_mangle` and `inline(never)` presumably keep `square` as a separate,
// predictably named function so the expected IR names stay stable; confirm before changing.
#[autodiff_forward(d_square3, Dual, DualOnly)]
#[autodiff_forward(d_square2, 4, Dual, DualOnly)]
#[autodiff_forward(d_square1, 4, Dual, Dual)]
#[no_mangle]
#[inline(never)]
fn square(x: &f32) -> f32 {
x * x
}
// d_square2
// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
// CHECK-NEXT: start:
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
// CHECK-NEXT: %4 = fmul float %"_2'ipl", 2.000000e+00
// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val
// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0
// CHECK-NEXT: %7 = fmul float %"_2'ipl1", 2.000000e+00
// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val
// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1
// CHECK-NEXT: %10 = fmul float %"_2'ipl2", 2.000000e+00
// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val
// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2
// CHECK-NEXT: %13 = fmul float %"_2'ipl3", 2.000000e+00
// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val
// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3
// CHECK-NEXT: ret [4 x float] %15
// CHECK-NEXT: }
// d_square1, the extra float is the original return value (x * x)
// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
// CHECK-NEXT: start:
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
// CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val
// CHECK-NEXT: %4 = fmul float %"_2'ipl", 2.000000e+00
// CHECK-NEXT: %5 = fmul fast float %4, %x.0.val
// CHECK-NEXT: %6 = insertvalue [4 x float] undef, float %5, 0
// CHECK-NEXT: %7 = fmul float %"_2'ipl1", 2.000000e+00
// CHECK-NEXT: %8 = fmul fast float %7, %x.0.val
// CHECK-NEXT: %9 = insertvalue [4 x float] %6, float %8, 1
// CHECK-NEXT: %10 = fmul float %"_2'ipl2", 2.000000e+00
// CHECK-NEXT: %11 = fmul fast float %10, %x.0.val
// CHECK-NEXT: %12 = insertvalue [4 x float] %9, float %11, 2
// CHECK-NEXT: %13 = fmul float %"_2'ipl3", 2.000000e+00
// CHECK-NEXT: %14 = fmul fast float %13, %x.0.val
// CHECK-NEXT: %15 = insertvalue [4 x float] %12, float %14, 3
// CHECK-NEXT: %16 = insertvalue { float, [4 x float] } undef, float %_0, 0
// CHECK-NEXT: %17 = insertvalue { float, [4 x float] } %16, [4 x float] %15, 1
// CHECK-NEXT: ret { float, [4 x float] } %17
// CHECK-NEXT: }
// Exercises `square` and its three generated forward-mode derivatives.
//
// According to the expected IR in this file, every returned tangent is `seed * 2.0 * x`, so
// with `x = 3.0` and seeds 1, 2, 3, 0 the tangents are 6, 12, 18 and 0.
fn main() {
    // `black_box` hides the literal from the optimizer so `x` is not treated as a constant.
    let x = std::hint::black_box(3.0);
    let output = square(&x);
    dbg!(&output);
    assert_eq!(9.0, output);
    dbg!(square(&x));
    // One seed per batch lane.
    let mut df_dx1 = 1.0;
    let mut df_dx2 = 2.0;
    let mut df_dx3 = 3.0;
    let mut df_dx4 = 0.0;
    // Batched call that yields only the four tangents.
    let [o1, o2, o3, o4] = d_square2(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
    dbg!(o1, o2, o3, o4);
    // Batched call that yields the primal result followed by the four tangents; the tangent
    // names intentionally shadow the ones from the previous call.
    let [output2, o1, o2, o3, o4] =
        d_square1(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
    dbg!(o1, o2, o3, o4);
    assert_eq!(output, output2);
    assert!((6.0 - o1).abs() < 1e-10);
    assert!((12.0 - o2).abs() < 1e-10);
    assert!((18.0 - o3).abs() < 1e-10);
    assert!((0.0 - o4).abs() < 1e-10);
    // The seeds are expected to be left unchanged by the derivative calls.
    assert_eq!(1.0, df_dx1);
    assert_eq!(2.0, df_dx2);
    assert_eq!(3.0, df_dx3);
    assert_eq!(0.0, df_dx4);
    // The single-seed derivative must agree with the matching lane of the batched result.
    // (Previously compared against `2.0 * oN`, which only holds for the zero seed.)
    assert_eq!(d_square3(&x, &mut df_dx1), o1);
    assert_eq!(d_square3(&x, &mut df_dx2), o2);
    assert_eq!(d_square3(&x, &mut df_dx3), o3);
    assert_eq!(d_square3(&x, &mut df_dx4), o4);
}