| #pragma once |
| |
// Separator that autodiff() places between the expanded primal arguments and
// the expanded shadow arguments.
extern int enzyme_interleave;

// Per-argument activity tags emitted by expand_args()/expand_primals() for
// Duplicated, DuplicatedNoNeed, Active and Const wrappers respectively.
extern int enzyme_dup, enzyme_dupnoneed, enzyme_out, enzyme_const;

// Return-activity tags; detail::ret_global picks one of them by address.
extern int enzyme_const_return, enzyme_active_return, enzyme_dup_return;

// Tags picked by detail::ret_used from the differentiation mode and the
// return activity.
extern int enzyme_primal_return, enzyme_noret;
| |
// Declaration only. Called by detail::autodiff_apply<ReverseMode<...>>::impl,
// which names the result type explicitly through the first template argument
// and passes everything else as a flat argument list.
template <class Return, class... Args>
Return __enzyme_autodiff(Args...);
| |
// Declaration only. Called by detail::autodiff_apply<ForwardMode>::impl,
// which names the result type explicitly through the first template argument
// and passes everything else as a flat argument list.
template <class Return, class... Args>
Return __enzyme_fwddiff(Args...);
| |
#include <type_traits>
#include <utility>

#include <enzyme/tuple>
| |
| namespace enzyme { |
| |
| #pragma clang diagnostic push |
| #pragma clang diagnostic ignored "-Wmissing-braces" |
| |
/// Empty placeholder reported as `type_info<T>::type` for every argument type
/// that is not an Active, Duplicated or DuplicatedNoNeed wrapper.
struct nodiff {
};
| |
/// Tag type selecting reverse-mode differentiation. `ReturnPrimal`
/// distinguishes the two flavours aliased below.
template <bool ReturnPrimal = false>
struct ReverseMode {};

using Reverse = ReverseMode<>;  // ReturnPrimal defaults to false
using ReverseWithPrimal = ReverseMode<true>;
| |
/// Tag type selecting forward-mode differentiation.
struct ForwardMode {};

/// Short spelling of the forward-mode tag.
using Forward = ForwardMode;
| |
/**
 * @brief Wraps a by-value argument that is passed with the `enzyme_out` tag.
 *
 * expand_args()/expand_primals() forward `value` preceded by `enzyme_out`.
 * Reference and pointer types are rejected at compile time.
 *
 * @tparam T type of the wrapped value
 */
template <typename T>
struct Active {
    static_assert(
        !std::is_reference_v<T> && !std::is_pointer_v<T>,
        "Reference/pointer active arguments don't make sense for AD!"
    );

    T value;

    Active(const T& v) : value(v) {}
    Active(T&& v) : value(std::move(v)) {}
    Active(const Active<T>&) = default;
    Active(Active<T>&&) = default;

    /// Mutable access to the wrapped value.
    operator T&() { return value; }

    /// Read-only access, so that a const Active can still be read implicitly.
    operator const T&() const { return value; }
};
| |
| namespace detail { |
| |
/**
 * @brief Helper base for all `Duplicated`-type arguments
 *
 * Duplicated and DuplicatedNoNeed share this storage and these constructors,
 * which avoids writing the same code twice.
 *
 * @tparam T type of the primal and of its shadow; pointer types are stored
 *           as regular values
 */
template <typename T>
struct DupArg {
    T value;   // primal
    T shadow;  // shadow paired with the primal

    DupArg(const T& v, const T& s) : value(v), shadow(s) {}

    // `v` and `s` are named, hence lvalues inside the constructor: without
    // std::move the members would be copy-constructed, which is wasteful and
    // does not compile for move-only T.
    DupArg(T&& v, T&& s) : value(std::move(v)), shadow(std::move(s)) {}

    DupArg(const DupArg<T>&) = default;
    DupArg(DupArg<T>&&) = default;
};

/**
 * @brief Reference flavour: stores the addresses of the primal and of the
 *        shadow instead of copies.
 */
template <typename T>
struct DupArg<T&> {
    T* value;
    T* shadow;

    DupArg(T& v, T& s) : value(&v), shadow(&s) {}
    DupArg(const DupArg<T&>&) = default;
    DupArg(DupArg<T&>&&) = default;
};
| |
| } // <-- namespace detail |
| |
| template < typename T > |
| struct Duplicated : public detail::DupArg<T> { |
| using detail::DupArg<T>::DupArg; |
| }; |
| |
| template < typename T > |
| struct DuplicatedNoNeed : public detail::DupArg<T> { |
| using detail::DupArg<T>::DupArg; |
| }; |
| |
/**
 * @brief Wraps an argument that is passed with the `enzyme_const` tag.
 *
 * @tparam T type of the wrapped value (stored by value)
 */
template <typename T>
struct Const {
    T value;

    Const(const T& v) : value(v) {}
    Const(T&& v) : value(std::move(v)) {}
    Const(const Const<T>&) = default;
    Const(Const<T>&&) = default;

    /// Mutable access to the wrapped value.
    operator T&() { return value; }

    /// Read-only access, so that a const wrapper can still be read implicitly.
    operator const T&() const { return value; }
};

/**
 * @brief Reference flavour: stores the address of the referenced object.
 */
template <typename T>
struct Const<T&> {
    T* value;

    Const(T& v) : value(&v) {}
    Const(const Const<T&>&) = default;
    Const(Const<T&>&&) = default;

    /// Yields the referenced object. The wrapper itself is not modified, so
    /// the conversion is usable on const wrappers as well.
    operator T&() const { return *value; }
};
| |
| // CTAD available in C++17 or later |
| #if __cplusplus >= 201703L |
| template < typename T > |
| Active(T) -> Active<T>; |
| |
| template < typename T > |
| Const(T) -> Const<T>; |
| |
| template < typename T > |
| Duplicated(T,T) -> Duplicated<T>; |
| |
| template < typename T > |
| DuplicatedNoNeed(T,T) -> DuplicatedNoNeed<T>; |
| #endif |
| |
| template < typename T > |
| struct type_info { |
| static constexpr bool is_active = false; |
| static constexpr bool is_duplicated = false; |
| using type = nodiff; |
| }; |
| |
| template < typename T > |
| struct type_info < Active<T> >{ |
| static constexpr bool is_active = true; |
| static constexpr bool is_duplicated = false; |
| using type = T; |
| }; |
| |
| template < typename T > |
| struct type_info < Duplicated<T> >{ |
| static constexpr bool is_active = false; |
| static constexpr bool is_duplicated = true; |
| using type = T; |
| }; |
| |
| template < typename T > |
| struct type_info < DuplicatedNoNeed<T> >{ |
| static constexpr bool is_active = false; |
| static constexpr bool is_duplicated = true; |
| using type = T; |
| }; |
| |
| template < typename ... T > |
| struct concatenated; |
| |
| template < typename ... S, typename T, typename ... rest > |
| struct concatenated < tuple < S ... >, T, rest ... > { |
| using type = typename concatenated< tuple< S ..., T>, rest ... >::type; |
| }; |
| |
| template < typename T > |
| struct concatenated < T > { |
| using type = T; |
| }; |
| |
| // Yikes! |
| // slightly cleaner in C++20, with std::remove_cvref |
| template < typename ... T > |
| struct autodiff_return; |
| |
| template < typename RetType, typename ... T > |
| struct autodiff_return<ReverseMode<false>, RetType, T...> |
| { |
| using type = tuple<typename concatenated< tuple< >, |
| typename type_info< |
| typename remove_cvref< T >::type |
| >::type ... |
| >::type>; |
| }; |
| |
| template < typename RetType, typename ... T > |
| struct autodiff_return<ReverseMode<true>, RetType, T...> |
| { |
| using type = tuple< |
| typename type_info<RetType>::type, |
| typename concatenated< tuple< >, |
| typename type_info< |
| typename remove_cvref< T >::type |
| >::type ... |
| >::type |
| >; |
| }; |
| |
| template < typename T0, typename ... T > |
| struct autodiff_return<ForwardMode, Const<T0>, T...> |
| { |
| using type = tuple<T0>; |
| }; |
| |
| template < typename T0, typename ... T > |
| struct autodiff_return<ForwardMode, Duplicated<T0>, T...> |
| { |
| using type = tuple<T0, T0>; |
| }; |
| |
| template < typename T0, typename ... T > |
| struct autodiff_return<ForwardMode, DuplicatedNoNeed<T0>, T...> |
| { |
| using type = tuple<T0>; |
| }; |
| |
| template < typename T > |
| __attribute__((always_inline)) |
| auto expand_args(const enzyme::Duplicated<T> & arg) { |
| return enzyme::tuple{enzyme_dup, arg.value, arg.shadow}; |
| } |
| |
| template < typename T > |
| __attribute__((always_inline)) |
| auto expand_args(const enzyme::DuplicatedNoNeed<T> & arg) { |
| return enzyme::tuple{enzyme_dupnoneed, arg.value, arg.shadow}; |
| } |
| |
| template < typename T > |
| __attribute__((always_inline)) |
| auto expand_args(const enzyme::Active<T> & arg) { |
| return enzyme::tuple<int, T>{enzyme_out, arg.value}; |
| } |
| |
| template < typename T > |
| __attribute__((always_inline)) |
| auto expand_args(const enzyme::Const<T> & arg) { |
| return enzyme::tuple{enzyme_const, arg.value}; |
| } |
| |
| template < typename T > |
| __attribute__((always_inline)) |
| auto expand_primals(const enzyme::Duplicated<T> & arg) { |
| return enzyme::tuple{enzyme_dup, arg.value}; |
| } |
| |
| template < typename T > |
| __attribute__((always_inline)) |
| auto expand_primals(const enzyme::DuplicatedNoNeed<T> & arg) { |
| return enzyme::tuple{enzyme_dupnoneed, arg.value}; |
| } |
| |
| template < typename T > |
| __attribute__((always_inline)) |
| auto expand_primals(const enzyme::Active<T> & arg) { |
| return enzyme::tuple<int, T>{enzyme_out, arg.value}; |
| } |
| |
| template < typename T > |
| __attribute__((always_inline)) |
| auto expand_primals(const enzyme::Const<T> & arg) { |
| return enzyme::tuple{enzyme_const, arg.value}; |
| } |
| |
| template < typename T > |
| __attribute__((always_inline)) |
| auto expand_shadows(const enzyme::Duplicated<T> & arg) { |
| return enzyme::tuple{arg.shadow}; |
| } |
| |
| template < typename T > |
| __attribute__((always_inline)) |
| auto expand_shadows(const enzyme::DuplicatedNoNeed<T> & arg) { |
| return enzyme::tuple{arg.shadow}; |
| } |
| |
| template < typename T > |
| __attribute__((always_inline)) |
| auto expand_shadows(const enzyme::Active<T> & arg) { |
| return enzyme::tuple{}; |
| } |
| |
| template < typename T > |
| __attribute__((always_inline)) |
| auto expand_shadows(const enzyme::Const<T> & arg) { |
| return enzyme::tuple{}; |
| } |
| |
| template < typename T > |
| __attribute__((always_inline)) |
| auto primal_args(const enzyme::Duplicated<T> & arg) { |
| return enzyme::tuple<T>{arg.value}; |
| } |
| |
| template < typename T > |
| __attribute__((always_inline)) |
| auto primal_args(const enzyme::DuplicatedNoNeed<T> & arg) { |
| return enzyme::tuple<T>{arg.value}; |
| } |
| |
| template < typename T > |
| __attribute__((always_inline)) |
| auto primal_args(const enzyme::Active<T> & arg) { |
| return enzyme::tuple<T>{arg.value}; |
| } |
| |
| template < typename T > |
| __attribute__((always_inline)) |
| auto primal_args(const enzyme::Const<T> & arg) { |
| return enzyme::tuple<T>{arg.value}; |
| } |
| |
// Declaration only, meant for decltype(): given any single-parameter class
// template instance Wrapper<Inner>, names Inner as the return type.
template <template <typename> class Wrapper, typename Inner>
Inner arg_type(const Wrapper<Inner>&);
| |
| namespace detail { |
// Spells the function type `RetType(T...)` as a nested `type`.
template <typename RetType, typename... T>
struct function_type {
    using type = auto(T...) -> RetType;
};
| |
// Adapter that turns a callable object into a plain static function taking
// the object by pointer. The primary template is empty; `wrap` exists only
// when the second argument is a function type RT(T...).
template <typename function, typename prevfunc>
struct templated_call {
};

template <typename function, typename RT, typename... T>
struct templated_call<function, RT(T...)> {
    // `function` is a reference type when the callable reached autodiff() as
    // an lvalue, and a pointer to a reference cannot be formed, so the
    // reference is stripped before building the pointer type.
    //
    // @param f    address of the callable object
    // @param args arguments passed on to the callable
    // @return whatever `(*f)(args...)` returns
    static RT wrap(std::remove_reference_t<function>* __restrict__ f, T... args) {
        return (*f)(args...);
    }
};
| |
| |
| |
| |
| template<typename T> |
| __attribute__((always_inline)) |
| constexpr decltype(auto) push_return_last(T &&t); |
| |
| template<typename ...T> |
| __attribute__((always_inline)) |
| constexpr decltype(auto) push_return_last(tuple<tuple<T...>> &&t) { |
| return tuple<tuple<T...>>{get<0>(t)}; |
| } |
| |
| template<typename ...T, typename R> |
| __attribute__((always_inline)) |
| constexpr decltype(auto) push_return_last(tuple<R, tuple<T...>> &&t) { |
| return tuple{get<1>(t), get<0>(t)}; |
| } |
| |
| template <typename Mode> |
| struct autodiff_apply {}; |
| |
| template <bool Mode> |
| struct autodiff_apply<ReverseMode<Mode>> { |
| template <class return_type, class Tuple, std::size_t... I, typename... ExtraArgs> |
| __attribute__((always_inline)) |
| static constexpr decltype(auto) impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>, ExtraArgs... args) { |
| return push_return_last(__enzyme_autodiff<return_type>(f, ret_attr, args..., enzyme::get<I>(impl::forward<Tuple>(t))...)); |
| } |
| }; |
| |
| template <> |
| struct autodiff_apply<ForwardMode> { |
| template <class return_type, class Tuple, std::size_t... I, typename... ExtraArgs> |
| __attribute__((always_inline)) |
| static constexpr return_type impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence<I...>, ExtraArgs... args) { |
| return __enzyme_fwddiff<return_type>(f, ret_attr, args..., enzyme::get<I>(impl::forward<Tuple>(t))...); |
| } |
| }; |
| |
| template <typename function, class Tuple, std::size_t... I> |
| __attribute__((always_inline)) |
| constexpr decltype(auto) primal_apply_impl(function &&f, Tuple&& t, std::index_sequence<I...>) { |
| return f(enzyme::get<I>(impl::forward<Tuple>(t))...); |
| } |
| |
| template < typename Mode, typename T > |
| struct default_ret_activity { |
| using type = Const<T>; |
| }; |
| |
| template <bool prim> |
| struct default_ret_activity<ReverseMode<prim>, float> { |
| using type = Active<float>; |
| }; |
| |
| template <bool prim> |
| struct default_ret_activity<ReverseMode<prim>, double> { |
| using type = Active<double>; |
| }; |
| |
| template<> |
| struct default_ret_activity<ForwardMode, float> { |
| using type = DuplicatedNoNeed<float>; |
| }; |
| |
| template<> |
| struct default_ret_activity<ForwardMode, double> { |
| using type = DuplicatedNoNeed<double>; |
| }; |
| |
| template < typename T > |
| struct ret_global; |
| |
| template<typename T> |
| struct ret_global<Const<T>> { |
| static constexpr int* value = &enzyme_const_return; |
| }; |
| |
| template<typename T> |
| struct ret_global<Active<T>> { |
| static constexpr int* value = &enzyme_active_return; |
| }; |
| |
| template<typename T> |
| struct ret_global<Duplicated<T>> { |
| static constexpr int* value = &enzyme_dup_return; |
| }; |
| |
| template<typename T> |
| struct ret_global<DuplicatedNoNeed<T>> { |
| static constexpr int* value = &enzyme_dup_return; |
| }; |
| |
| template<typename Mode, typename RetAct> |
| struct ret_used; |
| |
| template<typename RetAct> |
| struct ret_used<ReverseMode<true>, RetAct> { |
| static constexpr int* value = &enzyme_primal_return; |
| }; |
| |
| template<typename RetAct> |
| struct ret_used<ReverseMode<false>, RetAct> { |
| static constexpr int* value = &enzyme_noret; |
| }; |
| |
| template<typename T> |
| struct ret_used<ForwardMode, DuplicatedNoNeed<T>> { |
| static constexpr int* value = &enzyme_noret; |
| }; |
| template<typename T> |
| struct ret_used<ForwardMode, Const<T>> { |
| static constexpr int* value = &enzyme_primal_return; |
| }; |
| template<typename T> |
| struct ret_used<ForwardMode, Duplicated<T>> { |
| static constexpr int* value = &enzyme_primal_return; |
| }; |
| |
| template <typename T, typename... Args> |
| struct verify_dup_args { |
| static constexpr bool value = true; |
| }; |
| |
| template <typename... Args> |
| struct verify_dup_args<enzyme::Reverse, Args...> { |
| static constexpr bool value = ( |
| ( |
| !type_info<Args>::is_duplicated |
| || std::is_reference_v<typename type_info<Args>::type> |
| || std::is_pointer_v<typename type_info<Args>::type> |
| ) && ... |
| ); |
| }; |
| |
// True when T is a function type ...
template <class Candidate>
struct is_function_or_function_ptr : std::is_function<Candidate> {
};

// ... or a pointer whose pointee is a function type.
template <class Pointee>
struct is_function_or_function_ptr<Pointee*> : std::is_function<Pointee> {
};
| } // namespace detail |
| |
| template < typename return_type, typename function, typename ... enz_arg_types > |
| __attribute__((always_inline)) |
| auto primal_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { |
| using Tuple = enzyme::tuple< enz_arg_types ... >; |
| return detail::primal_apply_impl<return_type>(std::move(f), impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{}); |
| } |
| |
| template < typename function, typename ... arg_types> |
| auto primal_call(function && f, arg_types && ... args) { |
| return primal_impl<function>(impl::forward<function>(f), enzyme::tuple_cat(primal_args(args)...)); |
| } |
| |
| template < typename return_type, typename DiffMode, typename function, typename functy, typename RetActivity, typename ... enz_arg_types, std::enable_if_t<!detail::is_function_or_function_ptr<typename remove_cvref< function >::type>::value, int> = 0> |
| __attribute__((always_inline)) |
| auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { |
| using Tuple = enzyme::tuple< enz_arg_types ... >; |
| return detail::autodiff_apply<DiffMode>::template impl<return_type>((void*)&detail::templated_call<function, functy>::wrap, detail::ret_global<RetActivity>::value, |
| impl::forward<Tuple>(arg_tup), |
| std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{}, &enzyme_const, &f); |
| } |
| |
| template < typename return_type, typename DiffMode, typename function, typename functy, typename RetActivity, typename ... enz_arg_types, std::enable_if_t<detail::is_function_or_function_ptr<typename remove_cvref< function >::type>::value, int> = 0> |
| __attribute__((always_inline)) |
| auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { |
| using Tuple = enzyme::tuple< enz_arg_types ... >; |
| return detail::autodiff_apply<DiffMode>::template impl<return_type>((void*)static_cast<functy*>(f), detail::ret_global<RetActivity>::value, impl::forward<Tuple>(arg_tup), std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{}); |
| } |
| |
| template < typename DiffMode, typename RetActivity, typename function, typename ... arg_types> |
| __attribute__((always_inline)) |
| auto autodiff(function && f, arg_types && ... args) { |
| static_assert( |
| detail::verify_dup_args<DiffMode, arg_types...>::value, |
| "Non-reference/pointer Duplicated/DuplicatedNoNeed args don't make " |
| "sense for Reverse mode AD" |
| ); |
| using primal_return_type = decltype(f(arg_type(args)...)); |
| using functy = typename detail::function_type<primal_return_type, decltype(arg_type(args))...>::type; |
| using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type; |
| return autodiff_impl<return_type, DiffMode, function, functy, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_primals(args)..., enzyme::tuple{enzyme_interleave}, expand_shadows(args)...)); |
| } |
| |
| template < typename DiffMode, typename function, typename ... arg_types> |
| __attribute__((always_inline)) |
| auto autodiff(function && f, arg_types && ... args) { |
| static_assert( |
| detail::verify_dup_args<DiffMode, arg_types...>::value, |
| "Non-reference/pointer Duplicated/DuplicatedNoNeed args don't make " |
| "sense for Reverse mode AD" |
| ); |
| using primal_return_type = decltype(f(arg_type(args)...)); |
| using functy = typename detail::function_type<primal_return_type, decltype(arg_type(args))...>::type; |
| using RetActivity = typename detail::default_ret_activity<DiffMode, primal_return_type>::type; |
| using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type; |
| return autodiff_impl<return_type, DiffMode, function, functy, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_primals(args)..., enzyme::tuple{enzyme_interleave}, expand_shadows(args)...)); |
| } |
| #pragma clang diagnostic pop |
| |
| } // namespace enzyme |