blob: 79951db94e21514c0f1a6d500cbfd9c6011d24f3 [file] [log] [blame] [edit]
// RUN: %eopt --split-input-file --enzyme --canonicalize --remove-unnecessary-enzyme-ops --enzyme-simplify-math %s | FileCheck %s
// Primal function under test: yields %lhs when %cond is true, otherwise %rhs.
func.func @select(%cond: i1, %lhs: f64, %rhs: f64) -> f64 {
  %picked = arith.select %cond, %lhs, %rhs : f64
  func.return %picked : f64
}
// Differentiation driver for @select. The condition %c is marked constant,
// while %a and %b are active. %dr is passed as the extra trailing operand and
// seeds the derivative of the result. The return is marked activenoneed, so
// the primal value is not produced; the op yields just the two gradients
// (for %a and %b, in that order).
func.func @dselect(%c: i1, %a: f64, %b: f64, %dr: f64) -> (f64, f64) {
  %grads:2 = enzyme.autodiff @select(%c, %a, %b, %dr) {
    activity = [#enzyme<activity enzyme_const>,
                #enzyme<activity enzyme_active>,
                #enzyme<activity enzyme_active>],
    ret_activity = [#enzyme<activity enzyme_activenoneed>]
  } : (i1, f64, f64, f64) -> (f64, f64)
  func.return %grads#0, %grads#1 : f64, f64
}
// CHECK: func.func private @diffeselect(%[[c:.+]]: i1, %[[a:.+]]: f64, %[[b:.+]]: f64, %[[dr:.+]]: f64) -> (f64, f64) {
// CHECK-NEXT: %[[zero:.+]] = arith.constant 0.000000e+00 : f64
// CHECK-NEXT: %[[da:.+]] = arith.select %[[c]], %[[dr]], %[[zero]] : f64
// CHECK-NEXT: %[[db:.+]] = arith.select %[[c]], %[[zero]], %[[dr]] : f64
// CHECK-NEXT: return %[[da]], %[[db]] : f64, f64
// CHECK-NEXT: }