// blob: 7005d74e7666e517803dadbb25f26630b5841dc8
#pragma once
// Marker globals. Only their declarations are visible in this header; the
// code below places them (by value, or by address for the return markers)
// into the argument lists it builds for __enzyme_autodiff / __enzyme_fwddiff.

// Inserted by autodiff() between the primal arguments and the shadow arguments.
extern int enzyme_interleave;
// Prefix emitted for Duplicated<T> arguments (expand_args / expand_primals).
extern int enzyme_dup;
// Prefix emitted for DuplicatedNoNeed<T> arguments.
extern int enzyme_dupnoneed;
// Prefix emitted for Active<T> arguments.
extern int enzyme_out;
// Prefix emitted for Const<T> arguments; its address is also passed ahead of
// the functor pointer in the non-function-pointer autodiff_impl overload.
extern int enzyme_const;
// Return-activity markers; their addresses are selected by detail::ret_global.
extern int enzyme_const_return;
extern int enzyme_active_return;
extern int enzyme_dup_return;
// Return-usage markers; their addresses are selected by detail::ret_used.
extern int enzyme_primal_return;
extern int enzyme_noret;
// Reverse-mode entry point used by detail::autodiff_apply<ReverseMode<...>>.
// Declaration only: no definition is visible in this header (presumably it is
// resolved by the Enzyme compiler plugin -- TODO confirm). `Return` must be
// supplied explicitly; the argument pack is deduced from the call.
template<typename Return, typename... T>
Return __enzyme_autodiff(T...);
// Forward-mode counterpart used by detail::autodiff_apply<ForwardMode>.
// Same shape and the same caveats as __enzyme_autodiff.
template<typename Return, typename... T>
Return __enzyme_fwddiff(T...);
#include <cstddef>
#include <type_traits>
#include <utility>

#include <enzyme/tuple>
namespace enzyme {
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wmissing-braces"
/// Empty placeholder type. The primary `type_info` template reports it as
/// `type` for every argument wrapper that is neither Active nor Duplicated*.
struct nodiff {
};
/// Tag type selecting reverse-mode differentiation.
///
/// @tparam ReturnPrimal  When `true`, `autodiff_return` adds a leading slot
///         for the return value and `detail::ret_used` selects
///         `enzyme_primal_return`; when `false`, `enzyme_noret` is selected.
template <bool ReturnPrimal = false>
struct ReverseMode {};

/// Reverse mode without the primal return value.
using Reverse = ReverseMode<false>;
/// Reverse mode that also yields the primal return value.
using ReverseWithPrimal = ReverseMode<true>;
/// Tag type selecting forward-mode differentiation
/// (dispatched to `__enzyme_fwddiff` by `detail::autodiff_apply`).
struct ForwardMode {};

/// Short alias for the forward-mode tag.
using Forward = ForwardMode;
/// Argument wrapper for a by-value "active" argument.
///
/// `expand_args` / `expand_primals` emit `enzyme_out` followed by `value` for
/// this wrapper, and `type_info<Active<T>>::is_active` is `true`.
///
/// @tparam T  Value type; references and pointers are rejected at compile time.
template <typename T>
struct Active {
  static_assert(
      !std::is_reference_v<T> && !std::is_pointer_v<T>,
      "Reference/pointer active arguments don't make sense for AD!"
  );

  /// The wrapped primal value (stored by value).
  T value;

  /// Wrap a copy of `init`.
  Active(const T& init) : value(init) {}
  /// Wrap `init` by moving from it.
  Active(T&& init) : value(std::move(init)) {}
  Active(const Active&) = default;
  Active(Active&&) = default;

  /// Implicit access to the stored value (non-const objects only).
  operator T&() { return value; }
};
namespace detail {
/**
 * @brief Shared storage for the `Duplicated` and `DuplicatedNoNeed` wrappers.
 *
 * Both wrappers hold a primal (`value`) and a shadow (`shadow`) of the same
 * type, so the members and constructors live here once and are inherited.
 *
 * @tparam T  Stored type. Used for plain values and for pointer types; the
 *            `T&` partial specialization below handles references.
 */
template <typename T>
struct DupArg {
  /// Primal argument, stored by value.
  T value;
  /// Shadow argument, stored by value.
  T shadow;

  /// Copy both halves.
  DupArg(const T& v, const T& s) : value(v), shadow(s) {}
  /// Move both halves. `v` and `s` are named rvalue references and therefore
  /// lvalues inside the constructor, so they must be moved explicitly;
  /// otherwise this overload copies (and rejects move-only `T`).
  DupArg(T&& v, T&& s) : value(std::move(v)), shadow(std::move(s)) {}
  DupArg(const DupArg<T>&) = default;
  DupArg(DupArg<T>&&) = default;
};

/**
 * @brief Reference flavour: stores the addresses of the referenced objects.
 *
 * `value` / `shadow` are pointers to the objects passed to the constructor;
 * the wrapper does not own them, so they must outlive it.
 */
template <typename T>
struct DupArg<T&> {
  /// Address of the primal object.
  T* value;
  /// Address of the shadow object.
  T* shadow;

  DupArg(T& v, T& s) : value(&v), shadow(&s) {}
  DupArg(const DupArg<T&>&) = default;
  DupArg(DupArg<T&>&&) = default;
};
} // <-- namespace detail
/// Argument wrapper carrying a primal and a shadow of type `T`.
///
/// Storage and constructors are inherited from `detail::DupArg<T>`;
/// `expand_args` / `expand_primals` prefix it with `enzyme_dup`.
template <typename T>
struct Duplicated : detail::DupArg<T> {
  using detail::DupArg<T>::DupArg;
};
/// Argument wrapper carrying a primal and a shadow of type `T`, tagged with
/// `enzyme_dupnoneed` instead of `enzyme_dup` when expanded.
///
/// Storage and constructors are inherited from `detail::DupArg<T>`.
template <typename T>
struct DuplicatedNoNeed : detail::DupArg<T> {
  using detail::DupArg<T>::DupArg;
};
/// Argument wrapper for a value that is passed through with the
/// `enzyme_const` marker (see `expand_args` / `expand_primals`).
///
/// @tparam T  Stored type; held by value in this primary template.
template <typename T>
struct Const {
  /// The wrapped value.
  T value;

  /// Wrap a copy of `init`.
  Const(const T& init) : value(init) {}
  /// Wrap `init` by moving from it.
  Const(T&& init) : value(std::move(init)) {}
  Const(const Const&) = default;
  Const(Const&&) = default;

  /// Implicit access to the stored value (non-const objects only).
  operator T&() { return value; }
};

/// Reference flavour: stores the address of the referenced object, which
/// must outlive the wrapper.
template <typename T>
struct Const<T&> {
  /// Address of the wrapped object.
  T* value;

  Const(T& target) : value(&target) {}
  Const(const Const&) = default;
  Const(Const&&) = default;

  /// Implicit access to the referenced object.
  operator T&() { return *value; }
};
// Class template argument deduction guides (C++17 and later), so the wrappers
// can be written without spelling the template argument, e.g. `Active{x}`.
// Every guide takes its argument(s) by value, so the deduced type is decayed.
#if __cplusplus >= 201703L
template <typename Value>
Active(Value) -> Active<Value>;

template <typename Value>
Const(Value) -> Const<Value>;

template <typename Value>
Duplicated(Value, Value) -> Duplicated<Value>;

template <typename Value>
DuplicatedNoNeed(Value, Value) -> DuplicatedNoNeed<Value>;
#endif
/// Compile-time description of an argument wrapper.
///
/// Members:
///  - `is_active`     : true only for `Active<T>`
///  - `is_duplicated` : true for `Duplicated<T>` and `DuplicatedNoNeed<T>`
///  - `type`          : the wrapped `T`, or `nodiff` for anything else
///
/// The primary template is the "anything else" case. It is also what a
/// cv/ref-qualified wrapper type matches, since only the exact wrapper
/// types are specialized.
template <typename Wrapper>
struct type_info {
  using type = nodiff;
  static constexpr bool is_active = false;
  static constexpr bool is_duplicated = false;
};

template <typename Wrapped>
struct type_info<Active<Wrapped>> {
  using type = Wrapped;
  static constexpr bool is_active = true;
  static constexpr bool is_duplicated = false;
};

template <typename Wrapped>
struct type_info<Duplicated<Wrapped>> {
  using type = Wrapped;
  static constexpr bool is_active = false;
  static constexpr bool is_duplicated = true;
};

template <typename Wrapped>
struct type_info<DuplicatedNoNeed<Wrapped>> {
  using type = Wrapped;
  static constexpr bool is_active = false;
  static constexpr bool is_duplicated = true;
};
/// Type-level append: `concatenated<tuple<A...>, B, C, ...>::type` is
/// `tuple<A..., B, C, ...>`.
template <typename... Ts>
struct concatenated;

/// Recursive step: move the first pending type to the end of the tuple.
template <typename... Collected, typename Next, typename... Pending>
struct concatenated<tuple<Collected...>, Next, Pending...> {
  using type = typename concatenated<tuple<Collected..., Next>, Pending...>::type;
};

/// Terminal step: nothing left to append, yield the accumulated type.
template <typename Done>
struct concatenated<Done> {
  using type = Done;
};
// Computes the tuple type handed to __enzyme_autodiff / __enzyme_fwddiff as
// their explicit `Return` argument.
// (Would read slightly cleaner in C++20 with std::remove_cvref.)
template <typename... Ts>
struct autodiff_return;

/// Reverse mode, no primal: a one-element tuple wrapping a tuple with one
/// slot per argument. Each slot is `type_info<...>::type` of the cv/ref-stripped
/// argument wrapper, i.e. the wrapped type for Active/Duplicated* and `nodiff`
/// for everything else.
template <typename RetType, typename... Args>
struct autodiff_return<ReverseMode<false>, RetType, Args...>
{
  using type = tuple<
      typename concatenated<
          tuple<>,
          typename type_info<typename remove_cvref<Args>::type>::type...
      >::type
  >;
};

/// Reverse mode with primal: `type_info<RetType>::type` first, followed by
/// the same per-argument tuple as above.
template <typename RetType, typename... Args>
struct autodiff_return<ReverseMode<true>, RetType, Args...>
{
  using type = tuple<
      typename type_info<RetType>::type,
      typename concatenated<
          tuple<>,
          typename type_info<typename remove_cvref<Args>::type>::type...
      >::type
  >;
};

/// Forward mode, Const return activity: a single `T0`.
template <typename T0, typename... Args>
struct autodiff_return<ForwardMode, Const<T0>, Args...>
{
  using type = tuple<T0>;
};

/// Forward mode, Duplicated return activity: two `T0` slots.
template <typename T0, typename... Args>
struct autodiff_return<ForwardMode, Duplicated<T0>, Args...>
{
  using type = tuple<T0, T0>;
};

/// Forward mode, DuplicatedNoNeed return activity: a single `T0`.
template <typename T0, typename... Args>
struct autodiff_return<ForwardMode, DuplicatedNoNeed<T0>, Args...>
{
  using type = tuple<T0>;
};
template < typename T >
__attribute__((always_inline))
auto expand_args(const enzyme::Duplicated<T> & arg) {
return enzyme::tuple{enzyme_dup, arg.value, arg.shadow};
}
template < typename T >
__attribute__((always_inline))
auto expand_args(const enzyme::DuplicatedNoNeed<T> & arg) {
return enzyme::tuple{enzyme_dupnoneed, arg.value, arg.shadow};
}
template < typename T >
__attribute__((always_inline))
auto expand_args(const enzyme::Active<T> & arg) {
return enzyme::tuple<int, T>{enzyme_out, arg.value};
}
template < typename T >
__attribute__((always_inline))
auto expand_args(const enzyme::Const<T> & arg) {
return enzyme::tuple{enzyme_const, arg.value};
}
template < typename T >
__attribute__((always_inline))
auto expand_primals(const enzyme::Duplicated<T> & arg) {
return enzyme::tuple{enzyme_dup, arg.value};
}
template < typename T >
__attribute__((always_inline))
auto expand_primals(const enzyme::DuplicatedNoNeed<T> & arg) {
return enzyme::tuple{enzyme_dupnoneed, arg.value};
}
template < typename T >
__attribute__((always_inline))
auto expand_primals(const enzyme::Active<T> & arg) {
return enzyme::tuple<int, T>{enzyme_out, arg.value};
}
template < typename T >
__attribute__((always_inline))
auto expand_primals(const enzyme::Const<T> & arg) {
return enzyme::tuple{enzyme_const, arg.value};
}
template < typename T >
__attribute__((always_inline))
auto expand_shadows(const enzyme::Duplicated<T> & arg) {
return enzyme::tuple{arg.shadow};
}
template < typename T >
__attribute__((always_inline))
auto expand_shadows(const enzyme::DuplicatedNoNeed<T> & arg) {
return enzyme::tuple{arg.shadow};
}
template < typename T >
__attribute__((always_inline))
auto expand_shadows(const enzyme::Active<T> & arg) {
return enzyme::tuple{};
}
template < typename T >
__attribute__((always_inline))
auto expand_shadows(const enzyme::Const<T> & arg) {
return enzyme::tuple{};
}
template < typename T >
__attribute__((always_inline))
auto primal_args(const enzyme::Duplicated<T> & arg) {
return enzyme::tuple<T>{arg.value};
}
template < typename T >
__attribute__((always_inline))
auto primal_args(const enzyme::DuplicatedNoNeed<T> & arg) {
return enzyme::tuple<T>{arg.value};
}
template < typename T >
__attribute__((always_inline))
auto primal_args(const enzyme::Active<T> & arg) {
return enzyme::tuple<T>{arg.value};
}
template < typename T >
__attribute__((always_inline))
auto primal_args(const enzyme::Const<T> & arg) {
return enzyme::tuple<T>{arg.value};
}
/// Declaration-only helper for use inside `decltype`: given any single-parameter
/// wrapper `Wrapper<Inner>`, the "return type" is `Inner`. It has no definition
/// in this header and is only used in unevaluated contexts here.
template <template <typename> class Wrapper, typename Inner>
Inner arg_type(const Wrapper<Inner>&);
namespace detail {
/// Builds the function type `Result(Params...)` from its pieces.
template <typename Result, typename... Params>
struct function_type
{
  using type = Result(Params...);
};
/// Adapter that turns "callable object + signature" into a plain static
/// function taking the object's address as its first parameter.
/// The primary template is intentionally empty; only the specialization for
/// a function-type signature provides `wrap`.
template <typename function, typename prevfunc>
struct templated_call {};

template <typename function, typename Result, typename... Params>
struct templated_call<function, Result(Params...)> {
  /// Invoke `*callee` with `params` and return its result.
  static Result wrap(function* __restrict__ callee, Params... params) {
    return (*callee)(params...);
  }
};
template<typename T>
__attribute__((always_inline))
constexpr decltype(auto) push_return_last(T &&t);
template<typename ...T>
__attribute__((always_inline))
constexpr decltype(auto) push_return_last(tuple<tuple<T...>> &&t) {
return tuple<tuple<T...>>{get<0>(t)};
}
template<typename ...T, typename R>
__attribute__((always_inline))
constexpr decltype(auto) push_return_last(tuple<R, tuple<T...>> &&t) {
return tuple{get<1>(t), get<0>(t)};
}
/// Mode dispatcher. `impl` unpacks the argument tuple `t` and calls the
/// matching Enzyme entry point as
///   entry<return_type>(f, ret_attr, args..., <elements of t>...)
/// The primary template is empty: unknown modes have no `impl`.
template <typename Mode>
struct autodiff_apply {};

/// Reverse mode (with or without primal): calls __enzyme_autodiff and passes
/// the result through push_return_last.
template <bool ReturnPrimal>
struct autodiff_apply<ReverseMode<ReturnPrimal>> {
  template <class return_type, class Tuple, std::size_t... Idx, typename... ExtraArgs>
  __attribute__((always_inline))
  static constexpr decltype(auto) impl(void* f, int* ret_attr, Tuple&& t,
                                       std::index_sequence<Idx...>, ExtraArgs... args) {
    return push_return_last(
        __enzyme_autodiff<return_type>(
            f, ret_attr, args..., enzyme::get<Idx>(impl::forward<Tuple>(t))...));
  }
};

/// Forward mode: calls __enzyme_fwddiff and returns its result as-is.
template <>
struct autodiff_apply<ForwardMode> {
  template <class return_type, class Tuple, std::size_t... Idx, typename... ExtraArgs>
  __attribute__((always_inline))
  static constexpr return_type impl(void* f, int* ret_attr, Tuple&& t,
                                    std::index_sequence<Idx...>, ExtraArgs... args) {
    return __enzyme_fwddiff<return_type>(
        f, ret_attr, args..., enzyme::get<Idx>(impl::forward<Tuple>(t))...);
  }
};
/// Calls `callee` with the elements of tuple `t` as individual arguments
/// (the index pack `Idx...` selects them in order) and returns the result.
template <typename function, class Tuple, std::size_t... Idx>
__attribute__((always_inline))
constexpr decltype(auto) primal_apply_impl(function&& callee, Tuple&& t,
                                           std::index_sequence<Idx...>) {
  return callee(enzyme::get<Idx>(impl::forward<Tuple>(t))...);
}
/// Return activity used by autodiff() when the caller does not name one.
///
///   any mode,     other types  -> Const<T>
///   reverse mode, float/double -> Active<T>
///   forward mode, float/double -> DuplicatedNoNeed<T>
template <typename Mode, typename Ret>
struct default_ret_activity {
  using type = Const<Ret>;
};

template <bool ReturnPrimal>
struct default_ret_activity<ReverseMode<ReturnPrimal>, float> {
  using type = Active<float>;
};

template <bool ReturnPrimal>
struct default_ret_activity<ReverseMode<ReturnPrimal>, double> {
  using type = Active<double>;
};

template <>
struct default_ret_activity<ForwardMode, float> {
  using type = DuplicatedNoNeed<float>;
};

template <>
struct default_ret_activity<ForwardMode, double> {
  using type = DuplicatedNoNeed<double>;
};
/// Maps a return-activity wrapper to the address of its marker global.
/// Declared without a body: only the four wrappers below are accepted.
/// Note that Duplicated and DuplicatedNoNeed share `enzyme_dup_return`.
template <typename RetActivity>
struct ret_global;

template <typename Ret>
struct ret_global<Const<Ret>> {
  static constexpr int* value = &enzyme_const_return;
};

template <typename Ret>
struct ret_global<Active<Ret>> {
  static constexpr int* value = &enzyme_active_return;
};

template <typename Ret>
struct ret_global<Duplicated<Ret>> {
  static constexpr int* value = &enzyme_dup_return;
};

template <typename Ret>
struct ret_global<DuplicatedNoNeed<Ret>> {
  static constexpr int* value = &enzyme_dup_return;
};
/// Selects, per mode and return activity, the address of either
/// `enzyme_primal_return` or `enzyme_noret`:
///
///   ReverseMode<true>,  any activity      -> enzyme_primal_return
///   ReverseMode<false>, any activity      -> enzyme_noret
///   ForwardMode,        DuplicatedNoNeed  -> enzyme_noret
///   ForwardMode,        Const             -> enzyme_primal_return
///   ForwardMode,        Duplicated        -> enzyme_primal_return
///
/// Other combinations have no definition.
template <typename Mode, typename RetActivity>
struct ret_used;

template <typename RetActivity>
struct ret_used<ReverseMode<true>, RetActivity> {
  static constexpr int* value = &enzyme_primal_return;
};

template <typename RetActivity>
struct ret_used<ReverseMode<false>, RetActivity> {
  static constexpr int* value = &enzyme_noret;
};

template <typename Ret>
struct ret_used<ForwardMode, DuplicatedNoNeed<Ret>> {
  static constexpr int* value = &enzyme_noret;
};

template <typename Ret>
struct ret_used<ForwardMode, Const<Ret>> {
  static constexpr int* value = &enzyme_primal_return;
};

template <typename Ret>
struct ret_used<ForwardMode, Duplicated<Ret>> {
  static constexpr int* value = &enzyme_primal_return;
};
template <typename T, typename... Args>
struct verify_dup_args {
static constexpr bool value = true;
};
template <typename... Args>
struct verify_dup_args<enzyme::Reverse, Args...> {
static constexpr bool value = (
(
!type_info<Args>::is_duplicated
|| std::is_reference_v<typename type_info<Args>::type>
|| std::is_pointer_v<typename type_info<Args>::type>
) && ...
);
};
/// Trait: true for function types and for pointers to functions, false for
/// everything else (including references to functions and functors).
template <class Callee>
struct is_function_or_function_ptr : std::is_function<Callee> {};

/// Pointer case: decided by the pointee.
template <class Pointee>
struct is_function_or_function_ptr<Pointee*> : std::is_function<Pointee> {};
} // namespace detail
/// Calls `f` with the elements of `arg_tup` as its arguments and returns the
/// result.
///
/// @tparam return_type  Kept for interface compatibility; not used to select
///                      the call. (primal_call passes its callable type here.)
/// @param f        Callable, forwarded with its original value category.
/// @param arg_tup  Tuple of call arguments, consumed by the call.
template < typename return_type, typename function, typename ... enz_arg_types >
__attribute__((always_inline))
auto primal_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
  using Tuple = enzyme::tuple< enz_arg_types ... >;
  // `f` is a forwarding reference: forward it instead of unconditionally
  // moving it, and let primal_apply_impl deduce the callable type. Pinning
  // that type explicitly turned the parameter into an lvalue reference for
  // lvalue functors, which the moved-from xvalue could not bind to.
  return detail::primal_apply_impl(
      impl::forward<function>(f),
      impl::forward<Tuple>(arg_tup),
      std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
}
/// Runs the original (undifferentiated) function: strips every wrapper in
/// `args` down to its primal value via primal_args() and calls `f` with them.
template <typename function, typename... arg_types>
auto primal_call(function&& f, arg_types&&... args) {
  return primal_impl<function>(
      impl::forward<function>(f),
      enzyme::tuple_cat(primal_args(args)...));
}
/// Functor / lambda overload (selected when `function`, stripped of cv/ref,
/// is neither a function type nor a function pointer).
///
/// Enzyme is handed the address of `templated_call<...>::wrap`, a static
/// trampoline whose first parameter is a pointer to the callable. The
/// callable's address is passed as the first differentiated argument,
/// preceded by the address of `enzyme_const`.
///
/// @tparam return_type  Tuple type requested from the Enzyme entry point.
/// @tparam DiffMode     Mode tag (ReverseMode<...> or ForwardMode).
/// @tparam function     Callable type as deduced by autodiff(); may be an
///                      lvalue reference type.
/// @tparam functy       Signature of the primal call, `Ret(Args...)`.
/// @tparam RetActivity  Return-activity wrapper, mapped via ret_global.
/// @param f        The callable to differentiate.
/// @param arg_tup  Already-expanded marker/argument tuple.
template < typename return_type, typename DiffMode, typename function, typename functy, typename RetActivity, typename ... enz_arg_types, std::enable_if_t<!detail::is_function_or_function_ptr<typename remove_cvref< function >::type>::value, int> = 0>
__attribute__((always_inline))
auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
  using Tuple = enzyme::tuple< enz_arg_types ... >;
  // `function` is `L&` when an lvalue functor is passed. The trampoline
  // declares a `function*` parameter, and a pointer to a reference is
  // ill-formed, so strip the reference first. cv-qualifiers are kept so that
  // `&f` still matches the trampoline's pointer type.
  using callable = std::remove_reference_t<function>;
  return detail::autodiff_apply<DiffMode>::template impl<return_type>(
      (void*)&detail::templated_call<callable, functy>::wrap,
      detail::ret_global<RetActivity>::value,
      impl::forward<Tuple>(arg_tup),
      std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{},
      &enzyme_const, &f);
}

/// Function / function-pointer overload: the function itself, converted to
/// `functy*`, is handed to Enzyme directly; no trampoline and no extra
/// leading arguments are needed.
template < typename return_type, typename DiffMode, typename function, typename functy, typename RetActivity, typename ... enz_arg_types, std::enable_if_t<detail::is_function_or_function_ptr<typename remove_cvref< function >::type>::value, int> = 0>
__attribute__((always_inline))
auto autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) {
  using Tuple = enzyme::tuple< enz_arg_types ... >;
  return detail::autodiff_apply<DiffMode>::template impl<return_type>(
      (void*)static_cast<functy*>(f),
      detail::ret_global<RetActivity>::value,
      impl::forward<Tuple>(arg_tup),
      std::make_index_sequence<enzyme::tuple_size_v<Tuple>>{});
}
template < typename DiffMode, typename RetActivity, typename function, typename ... arg_types>
__attribute__((always_inline))
auto autodiff(function && f, arg_types && ... args) {
static_assert(
detail::verify_dup_args<DiffMode, arg_types...>::value,
"Non-reference/pointer Duplicated/DuplicatedNoNeed args don't make "
"sense for Reverse mode AD"
);
using primal_return_type = decltype(f(arg_type(args)...));
using functy = typename detail::function_type<primal_return_type, decltype(arg_type(args))...>::type;
using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
return autodiff_impl<return_type, DiffMode, function, functy, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_primals(args)..., enzyme::tuple{enzyme_interleave}, expand_shadows(args)...));
}
template < typename DiffMode, typename function, typename ... arg_types>
__attribute__((always_inline))
auto autodiff(function && f, arg_types && ... args) {
static_assert(
detail::verify_dup_args<DiffMode, arg_types...>::value,
"Non-reference/pointer Duplicated/DuplicatedNoNeed args don't make "
"sense for Reverse mode AD"
);
using primal_return_type = decltype(f(arg_type(args)...));
using functy = typename detail::function_type<primal_return_type, decltype(arg_type(args))...>::type;
using RetActivity = typename detail::default_ret_activity<DiffMode, primal_return_type>::type;
using return_type = typename autodiff_return<DiffMode, RetActivity, arg_types...>::type;
return autodiff_impl<return_type, DiffMode, function, functy, RetActivity>(impl::forward<function>(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used<DiffMode, RetActivity>::value}, expand_primals(args)..., enzyme::tuple{enzyme_interleave}, expand_shadows(args)...));
}
#pragma clang diagnostic pop
} // namespace enzyme