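Rust adbench: add Rust and C++ restrict variants to the ADBench reverse-mode benchmark harnesses (ba, gmm, lstm)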
Co-authored-by: Lorenz Schmidt <bytesnake@mailbox.org>
Co-authored-by: Jed Brown <jed@jedbrown.org>
Co-authored-by: William Moses <gh@wsmoses.com>
diff --git a/.gitignore b/.gitignore
index 5e7285d..62c76d1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,6 +8,7 @@
enzyme/benchmarks/ReverseMode/*/*.bc
enzyme/benchmarks/ReverseMode/*/*.o
enzyme/benchmarks/ReverseMode/*/*.exe
+enzyme/benchmarks/ReverseMode/*/target/
enzyme/benchmarks/ReverseMode/*/results.txt
enzyme/benchmarks/ReverseMode/*/results.json
.cache
diff --git a/enzyme/benchmarks/ReverseMode/adbench/ba.h b/enzyme/benchmarks/ReverseMode/adbench/ba.h
index 3ade86a..6a3f977 100644
--- a/enzyme/benchmarks/ReverseMode/adbench/ba.h
+++ b/enzyme/benchmarks/ReverseMode/adbench/ba.h
@@ -115,60 +115,68 @@
};
extern "C" {
- void ba_objective(
- int n,
- int m,
- int p,
- double const* cams,
- double const* X,
- double const* w,
- int const* obs,
- double const* feats,
- double* reproj_err,
- double* w_err
- );
+void ba_objective_restrict(int n, int m, int p, double const *cams,
+ double const *X, double const *w, int const *obs,
+ double const *feats, double *reproj_err,
+ double *w_err);
- void dcompute_reproj_error(
- double const* cam,
- double * dcam,
- double const* X,
- double * dX,
- double const* w,
- double * wb,
- double const* feat,
- double *err,
- double *derr
- );
+void ba_objective(int n, int m, int p, double const *cams, double const *X,
+ double const *w, int const *obs, double const *feats,
+ double *reproj_err, double *w_err);
- void dcompute_zach_weight_error(double const* w, double* dw, double* err, double* derr);
+void rust2_unsafe_ba_objective(int n, int m, int p, double const *cams,
+ double const *X, double const *w, int const *obs,
+ double const *feats, double *reproj_err,
+ double *w_err);
- void compute_reproj_error_b(
- double const* cam,
- double * dcam,
- double const* X,
- double * dX,
- double const* w,
- double * wb,
- double const* feat,
- double *err,
- double *derr
- );
+void rust2_ba_objective(int n, int m, int p, double const *cams,
+ double const *X, double const *w, int const *obs,
+ double const *feats, double *reproj_err, double *w_err);
- void compute_zach_weight_error_b(double const* w, double* dw, double* err, double* derr);
+void dcompute_reproj_error_restrict(double const *cam, double *dcam,
+ double const *X, double *dX,
+ double const *w, double *wb,
+ double const *feat, double *err,
+ double *derr);
- void adept_compute_reproj_error(
- double const* cam,
- double * dcam,
- double const* X,
- double * dX,
- double const* w,
- double * wb,
- double const* feat,
- double *err,
- double *derr
- );
+void dcompute_zach_weight_error_restrict(double const *w, double *dw,
+ double *err, double *derr);
- void adept_compute_zach_weight_error(double const* w, double* dw, double* err, double* derr);
+void dcompute_reproj_error(double const *cam, double *dcam, double const *X,
+ double *dX, double const *w, double *wb,
+ double const *feat, double *err, double *derr);
+
+void dcompute_zach_weight_error(double const *w, double *dw, double *err,
+ double *derr);
+
+void compute_reproj_error_b(double const *cam, double *dcam, double const *X,
+ double *dX, double const *w, double *wb,
+ double const *feat, double *err, double *derr);
+
+void compute_zach_weight_error_b(double const *w, double *dw, double *err,
+ double *derr);
+
+void adept_compute_reproj_error(double const *cam, double *dcam,
+ double const *X, double *dX, double const *w,
+ double *wb, double const *feat, double *err,
+ double *derr);
+
+void adept_compute_zach_weight_error(double const *w, double *dw, double *err,
+ double *derr);
+
+void rust_unsafe_dcompute_reproj_error(double const *cam, double *dcam,
+ double const *X, double *dX,
+ double const *w, double *wb,
+ double const *feat, double *err,
+ double *derr);
+
+void rust_dcompute_reproj_error(double const *cam, double *dcam,
+ double const *X, double *dX, double const *w,
+ double *wb, double const *feat, double *err,
+ double *derr);
+
+void rust_dcompute_zach_weight_error(double const *w, double *dw, double *err,
+ double *derr);
}
void read_ba_instance(const string& fn,
@@ -335,10 +343,22 @@
std::string path = "/mnt/Data/git/Enzyme/apps/ADBench/data/ba/ba1_n49_m7776_p31843.txt";
std::vector<std::string> paths = {
- "ba10_n1197_m126327_p563734.txt", "ba14_n356_m226730_p1255268.txt", "ba18_n1936_m649673_p5213733.txt", "ba2_n21_m11315_p36455.txt", "ba6_n539_m65220_p277273.txt", "test.txt",
- "ba11_n1723_m156502_p678718.txt", "ba15_n1102_m780462_p4052340.txt", "ba19_n4585_m1324582_p9125125.txt", "ba3_n161_m48126_p182072.txt", "ba7_n93_m61203_p287451.txt",
- "ba12_n253_m163691_p899155.txt", "ba16_n1544_m942409_p4750193.txt", "ba1_n49_m7776_p31843.txt", "ba4_n372_m47423_p204472.txt", "ba8_n88_m64298_p383937.txt",
- "ba13_n245_m198739_p1091386.txt", "ba17_n1778_m993923_p5001946.txt", "ba20_n13682_m4456117_p2987644.txt", "ba5_n257_m65132_p225911.txt", "ba9_n810_m88814_p393775.txt",
+ "ba10_n1197_m126327_p563734.txt",
+ "ba14_n356_m226730_p1255268.txt", // "ba18_n1936_m649673_p5213733.txt",
+ // "ba2_n21_m11315_p36455.txt",
+ // "ba6_n539_m65220_p277273.txt",
+ // "test.txt",
+ // "ba11_n1723_m156502_p678718.txt",
+ // "ba15_n1102_m780462_p4052340.txt",
+ // "ba19_n4585_m1324582_p9125125.txt",
+ // "ba3_n161_m48126_p182072.txt", "ba7_n93_m61203_p287451.txt",
+ // "ba12_n253_m163691_p899155.txt",
+ // "ba16_n1544_m942409_p4750193.txt", "ba1_n49_m7776_p31843.txt",
+ // "ba4_n372_m47423_p204472.txt", "ba8_n88_m64298_p383937.txt",
+ // "ba13_n245_m198739_p1091386.txt",
+ // "ba17_n1778_m993923_p5001946.txt",
+ // "ba20_n13682_m4456117_p2987644.txt",
+ // "ba5_n257_m65132_p225911.txt", "ba9_n810_m88814_p393775.txt",
};
std::ofstream jsonfile("results.json", std::ofstream::trunc);
@@ -358,27 +378,6 @@
BASparseMat(input.n, input.m, input.p)
};
- //BASparseMat(this->input.n, this->input.m, this->input.p)
-
- /*
- ba_objective(
- input.n,
- input.m,
- input.p,
- input.cams.data(),
- input.X.data(),
- input.w.data(),
- input.obs.data(),
- input.feats.data(),
- result.reproj_err.data(),
- result.w_err.data()
- );
-
- for(unsigned i=0; i<input.p; i++) {
- //printf("w_err[%d]=%f reproj_err[%d]=%f, reproj_err[%d]=%f\n", i, result.w_err[i], 2*i, result.reproj_err[2*i], 2*i+1, result.reproj_err[2*i+1]);
- }
- */
-
{
struct timeval start, end;
gettimeofday(&start, NULL);
@@ -409,10 +408,198 @@
BASparseMat(input.n, input.m, input.p)
};
- //BASparseMat(this->input.n, this->input.m, this->input.p)
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ calculate_jacobian<adept_compute_reproj_error,
+ adept_compute_zach_weight_error>(input, result);
+ gettimeofday(&end, NULL);
+ printf("Adept combined %0.6f\n", tdiff(&start, &end));
+ json adept;
+ adept["name"] = "Adept combined";
+ adept["runtime"] = tdiff(&start, &end);
+ for (unsigned i = 0; i < 5; i++) {
+ printf("%f ", result.J.vals[i]);
+ adept["result"].push_back(result.J.vals[i]);
+ }
+ printf("\n");
+ test_suite["tools"].push_back(adept);
+ }
+ }
- /*
- ba_objective(
+ {
+
+ struct BAInput input;
+ read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
+ input.X, input.w, input.obs, input.feats);
+
+ struct BAOutput result = {std::vector<double>(2 * input.p),
+ std::vector<double>(input.p),
+ BASparseMat(input.n, input.m, input.p)};
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ calculate_jacobian<dcompute_reproj_error_restrict,
+ dcompute_zach_weight_error_restrict>(input, result);
+ gettimeofday(&end, NULL);
+ printf("Enzyme restrict c++ combined %0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "Enzyme restrict c++ combined";
+ enzyme["runtime"] = tdiff(&start, &end);
+ for (unsigned i = 0; i < 5; i++) {
+ printf("%f ", result.J.vals[i]);
+ enzyme["result"].push_back(result.J.vals[i]);
+ }
+ printf("\n");
+ test_suite["tools"].push_back(enzyme);
+ }
+ }
+
+ {
+
+ struct BAInput input;
+ read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
+ input.X, input.w, input.obs, input.feats);
+
+ struct BAOutput result = {std::vector<double>(2 * input.p),
+ std::vector<double>(input.p),
+ BASparseMat(input.n, input.m, input.p)};
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ calculate_jacobian<dcompute_reproj_error, dcompute_zach_weight_error>(
+ input, result);
+ gettimeofday(&end, NULL);
+ printf("Enzyme aliasing c++ combined %0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "Enzyme c++ combined";
+ enzyme["runtime"] = tdiff(&start, &end);
+ for (unsigned i = 0; i < 5; i++) {
+ printf("%f ", result.J.vals[i]);
+ enzyme["result"].push_back(result.J.vals[i]);
+ }
+ printf("\n");
+ test_suite["tools"].push_back(enzyme);
+ }
+ }
+
+ {
+ struct BAInput input;
+ read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
+ input.X, input.w, input.obs, input.feats);
+
+ struct BAOutput result = {std::vector<double>(2 * input.p),
+ std::vector<double>(input.p),
+ BASparseMat(input.n, input.m, input.p)};
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ ba_objective_restrict(input.n, input.m, input.p, input.cams.data(),
+ input.X.data(), input.w.data(), input.obs.data(),
+ input.feats.data(), result.reproj_err.data(),
+ result.w_err.data());
+ gettimeofday(&end, NULL);
+ printf("primal restrict c++ t=%0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "primal restrict c++";
+ enzyme["runtime"] = tdiff(&start, &end);
+ for (unsigned i = 0; i < 5; i++) {
+ printf("%f ", result.reproj_err[i]);
+ enzyme["result"].push_back(result.reproj_err[i]);
+ }
+ for (unsigned i = 0; i < 5; i++) {
+ printf("%f ", result.w_err[i]);
+ enzyme["result"].push_back(result.w_err[i]);
+ }
+ printf("\n");
+ test_suite["tools"].push_back(enzyme);
+ }
+ }
+
+ {
+ struct BAInput input;
+ read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
+ input.X, input.w, input.obs, input.feats);
+
+ struct BAOutput result = {std::vector<double>(2 * input.p),
+ std::vector<double>(input.p),
+ BASparseMat(input.n, input.m, input.p)};
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ ba_objective(input.n, input.m, input.p, input.cams.data(), input.X.data(),
+ input.w.data(), input.obs.data(), input.feats.data(),
+ result.reproj_err.data(), result.w_err.data());
+ gettimeofday(&end, NULL);
+ printf("primal aliasing c++ t=%0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "primal aliasing c++";
+ enzyme["runtime"] = tdiff(&start, &end);
+ for(unsigned i=0; i<5; i++) {
+ printf("%f ", result.reproj_err[i]);
+ enzyme["result"].push_back(result.reproj_err[i]);
+ }
+ for(unsigned i=0; i<5; i++) {
+ printf("%f ", result.w_err[i]);
+ enzyme["result"].push_back(result.w_err[i]);
+ }
+ printf("\n");
+ test_suite["tools"].push_back(enzyme);
+ }
+ }
+
+ {
+ struct BAInput input;
+ read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
+ input.X, input.w, input.obs, input.feats);
+
+ struct BAOutput result = {std::vector<double>(2 * input.p),
+ std::vector<double>(input.p),
+ BASparseMat(input.n, input.m, input.p)};
+ {
+
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ rust2_unsafe_ba_objective(input.n, input.m, input.p, input.cams.data(),
+ input.X.data(), input.w.data(),
+ input.obs.data(), input.feats.data(),
+ result.reproj_err.data(), result.w_err.data());
+ gettimeofday(&end, NULL);
+ printf("primal unsafe rust t=%0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "primal unsafe rust";
+ enzyme["runtime"] = tdiff(&start, &end);
+ for (unsigned i = 0; i < 5; i++) {
+ printf("%f ", result.reproj_err[i]);
+ enzyme["result"].push_back(result.reproj_err[i]);
+ }
+ for (unsigned i = 0; i < 5; i++) {
+ printf("%f ", result.w_err[i]);
+ enzyme["result"].push_back(result.w_err[i]);
+ }
+ printf("\n");
+ test_suite["tools"].push_back(enzyme);
+ }
+ }
+
+ {
+ struct BAInput input;
+ read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams, input.X, input.w, input.obs, input.feats);
+
+ struct BAOutput result = {
+ std::vector<double>(2 * input.p),
+ std::vector<double>(input.p),
+ BASparseMat(input.n, input.m, input.p)
+ };
+ {
+
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ rust2_ba_objective(
input.n,
input.m,
input.p,
@@ -424,29 +611,22 @@
result.reproj_err.data(),
result.w_err.data()
);
-
- for(unsigned i=0; i<input.p; i++) {
- //printf("w_err[%d]=%f reproj_err[%d]=%f, reproj_err[%d]=%f\n", i, result.w_err[i], 2*i, result.reproj_err[2*i], 2*i+1, result.reproj_err[2*i+1]);
- }
- */
-
- {
- struct timeval start, end;
- gettimeofday(&start, NULL);
- calculate_jacobian<adept_compute_reproj_error, adept_compute_zach_weight_error>(input, result);
gettimeofday(&end, NULL);
- printf("Adept combined %0.6f\n", tdiff(&start, &end));
- json adept;
- adept["name"] = "Adept combined";
- adept["runtime"] = tdiff(&start, &end);
+ printf("primal rust t=%0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "primal rust";
+ enzyme["runtime"] = tdiff(&start, &end);
for(unsigned i=0; i<5; i++) {
- printf("%f ", result.J.vals[i]);
- adept["result"].push_back(result.J.vals[i]);
+ printf("%f ", result.reproj_err[i]);
+ enzyme["result"].push_back(result.reproj_err[i]);
+ }
+ for(unsigned i=0; i<5; i++) {
+ printf("%f ", result.w_err[i]);
+ enzyme["result"].push_back(result.w_err[i]);
}
printf("\n");
- test_suite["tools"].push_back(adept);
+ test_suite["tools"].push_back(enzyme);
}
-
}
{
@@ -460,35 +640,43 @@
BASparseMat(input.n, input.m, input.p)
};
- //BASparseMat(this->input.n, this->input.m, this->input.p)
-
- /*
- ba_objective(
- input.n,
- input.m,
- input.p,
- input.cams.data(),
- input.X.data(),
- input.w.data(),
- input.obs.data(),
- input.feats.data(),
- result.reproj_err.data(),
- result.w_err.data()
- );
-
- for(unsigned i=0; i<input.p; i++) {
- //printf("w_err[%d]=%f reproj_err[%d]=%f, reproj_err[%d]=%f\n", i, result.w_err[i], 2*i, result.reproj_err[2*i], 2*i+1, result.reproj_err[2*i+1]);
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ calculate_jacobian<rust_unsafe_dcompute_reproj_error,
+ rust_dcompute_zach_weight_error>(input, result);
+ gettimeofday(&end, NULL);
+ printf("Enzyme unsafe rust combined %0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "Enzyme unsafe rust combined";
+ enzyme["runtime"] = tdiff(&start, &end);
+ for (unsigned i = 0; i < 5; i++) {
+ printf("%f ", result.J.vals[i]);
+ enzyme["result"].push_back(result.J.vals[i]);
+ }
+ printf("\n");
+ test_suite["tools"].push_back(enzyme);
}
- */
+ }
+
+ {
+
+ struct BAInput input;
+ read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
+ input.X, input.w, input.obs, input.feats);
+
+ struct BAOutput result = {std::vector<double>(2 * input.p),
+ std::vector<double>(input.p),
+ BASparseMat(input.n, input.m, input.p)};
{
struct timeval start, end;
gettimeofday(&start, NULL);
- calculate_jacobian<dcompute_reproj_error, dcompute_zach_weight_error>(input, result);
+ calculate_jacobian<rust_dcompute_reproj_error, rust_dcompute_zach_weight_error>(input, result);
gettimeofday(&end, NULL);
- printf("Enzyme combined %0.6f\n", tdiff(&start, &end));
+ printf("Enzyme rust combined %0.6f\n", tdiff(&start, &end));
json enzyme;
- enzyme["name"] = "Enzyme combined";
+ enzyme["name"] = "Enzyme rust combined";
enzyme["runtime"] = tdiff(&start, &end);
for(unsigned i=0; i<5; i++) {
printf("%f ", result.J.vals[i]);
@@ -497,8 +685,8 @@
printf("\n");
test_suite["tools"].push_back(enzyme);
}
-
}
+
test_suite["llvm-version"] = __clang_version__;
test_suite["mode"] = "ReverseMode";
test_suite["batch-size"] = 1;
diff --git a/enzyme/benchmarks/ReverseMode/adbench/gmm.h b/enzyme/benchmarks/ReverseMode/adbench/gmm.h
index feef3a7..c5ec727 100644
--- a/enzyme/benchmarks/ReverseMode/adbench/gmm.h
+++ b/enzyme/benchmarks/ReverseMode/adbench/gmm.h
@@ -18,7 +18,7 @@
using json = nlohmann::json;
struct GMMInput {
- int d, k, n;
+ size_t d, k, n;
std::vector<double> alphas, means, icf, x;
Wishart wishart;
};
@@ -33,24 +33,54 @@
};
extern "C" {
- void dgmm_objective(int d, int k, int n, const double *alphas, double *
- alphasb, const double *means, double *meansb, const double *icf,
- double *icfb, const double *x, Wishart wishart, double *err, double *
- errb);
+void gmm_objective(size_t d, size_t k, size_t n, double const *alphas,
+ double const *means, double const *icf, double const *x,
+ Wishart wishart, double *err);
+void gmm_objective_restrict(size_t d, size_t k, size_t n, double const *alphas,
+ double const *means, double const *icf,
+ double const *x, Wishart wishart, double *err);
+void dgmm_objective_restrict(size_t d, size_t k, size_t n, const double *alphas,
+ double *alphasb, const double *means,
+ double *meansb, const double *icf, double *icfb,
+ const double *x, Wishart wishart, double *err,
+ double *errb);
+void dgmm_objective(size_t d, size_t k, size_t n, const double *alphas, double *alphasb,
+ const double *means, double *meansb, const double *icf,
+ double *icfb, const double *x, Wishart wishart, double *err,
+ double *errb);
- void gmm_objective_b(int d, int k, int n, const double *alphas, double *
- alphasb, const double *means, double *meansb, const double *icf,
- double *icfb, const double *x, Wishart wishart, double *err, double *
- errb);
+void gmm_objective_b(size_t d, size_t k, size_t n, const double *alphas, double *alphasb,
+ const double *means, double *meansb, const double *icf,
+ double *icfb, const double *x, Wishart wishart,
+ double *err, double *errb);
- void adept_dgmm_objective(int d, int k, int n, const double *alphas, double *
- alphasb, const double *means, double *meansb, const double *icf,
- double *icfb, const double *x, Wishart wishart, double *err, double *
- errb);
+void adept_dgmm_objective(size_t d, size_t k, size_t n, const double *alphas,
+ double *alphasb, const double *means, double *meansb,
+ const double *icf, double *icfb, const double *x,
+ Wishart wishart, double *err, double *errb);
+
+void rust_unsafe_dgmm_objective(size_t d, size_t k, size_t n, const double *alphas,
+ double *alphasb, const double *means,
+ double *meansb, const double *icf, double *icfb,
+ const double *x, Wishart &wishart, double *err,
+ double *errb);
+
+void rust_unsafe_gmm_objective(size_t d, size_t k, size_t n, const double *alphas,
+ const double *means, const double *icf,
+ const double *x, Wishart &wishart, double *err);
+
+void rust_dgmm_objective(size_t d, size_t k, size_t n, const double *alphas,
+ double *alphasb, const double *means, double *meansb,
+ const double *icf, double *icfb, const double *x,
+ Wishart &wishart, double *err, double *errb);
+
+void rust_gmm_objective(size_t d, size_t k, size_t n, const double *alphas,
+ const double *means, const double *icf, const double *x,
+ Wishart &wishart, double *err);
}
void read_gmm_instance(const string& fn,
- int* d, int* k, int* n,
+ size_t* d, size_t* k, size_t* n,
vector<double>& alphas,
vector<double>& means,
vector<double>& icf,
@@ -65,32 +95,32 @@
exit(1);
}
- fscanf(fid, "%i %i %i", d, k, n);
+ fscanf(fid, "%zu %zu %zu", d, k, n);
- int d_ = *d, k_ = *k, n_ = *n;
+ size_t d_ = *d, k_ = *k, n_ = *n;
- int icf_sz = d_ * (d_ + 1) / 2;
+ size_t icf_sz = d_ * (d_ + 1) / 2;
alphas.resize(k_);
means.resize(d_ * k_);
icf.resize(icf_sz * k_);
x.resize(d_ * n_);
- for (int i = 0; i < k_; i++)
+ for (size_t i = 0; i < k_; i++)
{
fscanf(fid, "%lf", &alphas[i]);
}
- for (int i = 0; i < k_; i++)
+ for (size_t i = 0; i < k_; i++)
{
- for (int j = 0; j < d_; j++)
+ for (size_t j = 0; j < d_; j++)
{
fscanf(fid, "%lf", &means[i * d_ + j]);
}
}
- for (int i = 0; i < k_; i++)
+ for (size_t i = 0; i < k_; i++)
{
- for (int j = 0; j < icf_sz; j++)
+ for (size_t j = 0; j < icf_sz; j++)
{
fscanf(fid, "%lf", &icf[i * icf_sz + j]);
}
@@ -98,20 +128,20 @@
if (replicate_point)
{
- for (int j = 0; j < d_; j++)
+ for (size_t j = 0; j < d_; j++)
{
fscanf(fid, "%lf", &x[j]);
}
- for (int i = 0; i < n_; i++)
+ for (size_t i = 0; i < n_; i++)
{
memcpy(&x[i * d_], &x[0], d_ * sizeof(double));
}
}
else
{
- for (int i = 0; i < n_; i++)
+ for (size_t i = 0; i < n_; i++)
{
- for (int j = 0; j < d_; j++)
+ for (size_t j = 0; j < d_; j++)
{
fscanf(fid, "%lf", &x[i * d_ + j]);
}
@@ -123,10 +153,7 @@
fclose(fid);
}
-typedef void(*deriv_t)(int d, int k, int n, const double *alphas, double *alphasb, const double *means, double *meansb, const double *icf,
- double *icfb, const double *x, Wishart wishart, double *err, double *errb);
-
-template<deriv_t deriv>
+template<auto deriv>
void calculate_jacobian(struct GMMInput &input, struct GMMOutput &result)
{
double* alphas_gradient_part = result.gradient.data();
@@ -159,13 +186,38 @@
);
}
+template<auto deriv> // deriv: primal objective entry point to invoke, e.g. gmm_objective or rust_gmm_objective
+double primal(struct GMMInput &input)
+{
+ double tmp = 0.0; // receives the objective value the callee is expected to store through its last pointer argument
+ // Invokes the objective once on the loaded instance (no gradient buffers are passed) and returns tmp.
+ deriv(
+ input.d,
+ input.k,
+ input.n,
+ input.alphas.data(),
+ input.means.data(),
+ input.icf.data(),
+ input.x.data(),
+ input.wishart,
+ &tmp
+ );
+ return tmp;
+}
+
int main(const int argc, const char* argv[]) {
printf("starting main\n");
const auto replicate_point = (argc > 9 && string(argv[9]) == "-rep");
const GMMParameters params = { replicate_point };
- std::vector<std::string> paths;// = { "1k/gmm_d10_K100.txt" };
+ std::vector<std::string> paths = { "10k/gmm_d10_K200.txt" };
+
+ //getTests(paths, "data/1k", "1k/");
+ if (std::getenv("BENCH_LARGE")) {
+ getTests(paths, "data/2.5k", "2.5k/");
+ getTests(paths, "data/10k", "10k/");
+ }
getTests(paths, "data/1k", "1k/");
if (std::getenv("BENCH_LARGE")) {
@@ -188,7 +240,7 @@
read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
input.alphas, input.means, input.icf, input.x, input.wishart, params.replicate_point);
- int Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
+ size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
struct GMMOutput result = { 0, std::vector<double>(Jcols) };
@@ -218,49 +270,82 @@
read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
input.alphas, input.means, input.icf, input.x, input.wishart, params.replicate_point);
- int Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
+ size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
struct GMMOutput result = { 0, std::vector<double>(Jcols) };
- try {
- struct timeval start, end;
- gettimeofday(&start, NULL);
- calculate_jacobian<adept_dgmm_objective>(input, result);
- gettimeofday(&end, NULL);
- printf("Adept combined %0.6f\n", tdiff(&start, &end));
- json adept;
- adept["name"] = "Adept combined";
- adept["runtime"] = tdiff(&start, &end);
- for (unsigned i = result.gradient.size() - 5;
- i < result.gradient.size(); i++) {
- printf("%f ", result.gradient[i]);
- adept["result"].push_back(result.gradient[i]);
+ //if (0) {
+ try {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ calculate_jacobian<adept_dgmm_objective>(input, result);
+ gettimeofday(&end, NULL);
+ printf("Adept combined %0.6f\n", tdiff(&start, &end));
+ json adept;
+ adept["name"] = "Adept combined";
+ adept["runtime"] = tdiff(&start, &end);
+ for (unsigned i = result.gradient.size() - 5;
+ i < result.gradient.size(); i++) {
+ printf("%f ", result.gradient[i]);
+ adept["result"].push_back(result.gradient[i]);
+ }
+ printf("\n");
+ test_suite["tools"].push_back(adept);
+ } catch (std::bad_alloc) {
+ printf("Adept combined 88888888 ooms\n");
}
- printf("\n");
- test_suite["tools"].push_back(adept);
- } catch(std::bad_alloc) {
- printf("Adept combined 88888888 ooms\n");
+ //}
}
- }
-
+ for (size_t i = 0; i < 5; i++)
{
struct GMMInput input;
read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
input.alphas, input.means, input.icf, input.x, input.wishart, params.replicate_point);
- int Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
+ size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
struct GMMOutput result = { 0, std::vector<double>(Jcols) };
{
struct timeval start, end;
gettimeofday(&start, NULL);
+ calculate_jacobian<dgmm_objective_restrict>(input, result);
+ gettimeofday(&end, NULL);
+ printf("Enzyme c++ restrict combined %0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "Enzyme restrict combined";
+ enzyme["runtime"] = tdiff(&start, &end);
+ for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
+ i++) {
+ printf("%f ", result.gradient[i]);
+ enzyme["result"].push_back(result.gradient[i]);
+ }
+ printf("\n");
+ test_suite["tools"].push_back(enzyme);
+ }
+ }
+
+ {
+
+ struct GMMInput input;
+ read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
+ input.alphas, input.means, input.icf, input.x,
+ input.wishart, params.replicate_point);
+
+ size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
+
+ struct GMMOutput result = {0, std::vector<double>(Jcols)};
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
calculate_jacobian<dgmm_objective>(input, result);
gettimeofday(&end, NULL);
+ printf("Enzyme c++ mayalias combined %0.6f\n", tdiff(&start, &end));
json enzyme;
- enzyme["name"] = "Enzyme combined";
+ enzyme["name"] = "Enzyme mayalias combined";
enzyme["runtime"] = tdiff(&start, &end);
for (unsigned i = result.gradient.size() - 5;
i < result.gradient.size(); i++) {
@@ -270,8 +355,132 @@
printf("\n");
test_suite["tools"].push_back(enzyme);
}
-
}
+
+ {
+
+ struct GMMInput input;
+ read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
+ input.alphas, input.means, input.icf, input.x,
+ input.wishart, params.replicate_point);
+
+ size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
+
+ struct GMMOutput result = {0, std::vector<double>(Jcols)};
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ calculate_jacobian<rust_unsafe_dgmm_objective>(input, result);
+ gettimeofday(&end, NULL);
+ printf("Enzyme unsafe rust combined %0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "Rust unsafe Enzyme combined";
+ enzyme["runtime"] = tdiff(&start, &end);
+ for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
+ i++) {
+ printf("%f ", result.gradient[i]);
+ enzyme["result"].push_back(result.gradient[i]);
+ }
+ printf("\n");
+ test_suite["tools"].push_back(enzyme);
+ }
+ }
+
+ for (size_t i = 0; i < 5; i++)
+ {
+
+ struct GMMInput input;
+ read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
+ input.alphas, input.means, input.icf, input.x,
+ input.wishart, params.replicate_point);
+
+ size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
+
+ struct GMMOutput result = {0, std::vector<double>(Jcols)};
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ calculate_jacobian<rust_dgmm_objective>(input, result);
+ gettimeofday(&end, NULL);
+ printf("Enzyme rust combined %0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "Rust Enzyme combined";
+ enzyme["runtime"] = tdiff(&start, &end);
+ for (unsigned i = result.gradient.size() - 5;
+ i < result.gradient.size(); i++) {
+ printf("%f ", result.gradient[i]);
+ enzyme["result"].push_back(result.gradient[i]);
+ }
+ printf("\n");
+ test_suite["tools"].push_back(enzyme);
+ }
+ }
+
+ {
+
+ struct GMMInput input;
+ read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
+ input.alphas, input.means, input.icf, input.x, input.wishart, params.replicate_point);
+
+ size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
+
+ struct GMMOutput result = { 0, std::vector<double>(Jcols) };
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ auto res = primal<gmm_objective>(input);
+ gettimeofday(&end, NULL);
+ printf("c++ primal mayalias combined t=%0.6f, err=%f\n",
+ tdiff(&start, &end), res);
+
+ json primal;
+ primal["name"] = "C++ primal mayalias";
+ primal["runtime"] = tdiff(&start, &end);
+ primal["result"].push_back(res);
+ test_suite["tools"].push_back(primal);
+ }
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ auto res = primal<gmm_objective_restrict>(input);
+ gettimeofday(&end, NULL);
+ printf("c++ primal restrict combined t=%0.6f, err=%f\n",
+ tdiff(&start, &end), res);
+
+ json primal;
+ primal["name"] = "C++ primal restrict";
+ primal["runtime"] = tdiff(&start, &end);
+ primal["result"].push_back(res);
+ test_suite["tools"].push_back(primal);
+ }
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ auto res = primal<rust_unsafe_gmm_objective>(input);
+ gettimeofday(&end, NULL);
+ printf("rust unsafe primal combined t=%0.6f, err=%f\n",
+ tdiff(&start, &end), res);
+ json primal;
+ primal["name"] = "Rust unsafe primal";
+ primal["runtime"] = tdiff(&start, &end);
+ primal["result"].push_back(res);
+ test_suite["tools"].push_back(primal);
+ }
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ auto res = primal<rust_gmm_objective>(input);
+ gettimeofday(&end, NULL);
+ printf("rust primal combined t=%0.6f, err=%f\n", tdiff(&start, &end), res);
+ json primal;
+ primal["name"] = "Rust primal";
+ primal["runtime"] = tdiff(&start, &end);
+ primal["result"].push_back(res);
+ test_suite["tools"].push_back(primal);
+ }
+ }
+
test_suite["llvm-version"] = __clang_version__;
test_suite["mode"] = "ReverseMode";
test_suite["batch-size"] = 1;
diff --git a/enzyme/benchmarks/ReverseMode/adbench/lstm.h b/enzyme/benchmarks/ReverseMode/adbench/lstm.h
index e6d1330..4f99841 100644
--- a/enzyme/benchmarks/ReverseMode/adbench/lstm.h
+++ b/enzyme/benchmarks/ReverseMode/adbench/lstm.h
@@ -34,37 +34,56 @@
};
extern "C" {
- void dlstm_objective(
- int l,
- int c,
- int b,
- double const* main_params,
- double* dmain_params,
- double const* extra_params,
- double* dextra_params,
- double* state,
- double const* sequence,
- double* loss,
- double* dloss
- );
+void rust_unsafe_dlstm_objective(int l, int c, int b, double const *main_params,
+ double *dmain_params,
+ double const *extra_params,
+ double *dextra_params, double *state,
+ double const *sequence, double *loss,
+ double *dloss);
- void lstm_objective_b(int l, int c, int b, const double *main_params, double *
- main_paramsb, const double *extra_params, double *extra_paramsb,
- double *state, const double *sequence, double *loss, double *lossb);
+void rust_unsafe_lstm_objective(int l, int c, int b, double const *main_params,
+ double const *extra_params, double *state,
+ double const *sequence, double *loss);
- void adept_dlstm_objective(
- int l,
- int c,
- int b,
- double const* main_params,
- double* dmain_params,
- double const* extra_params,
- double* dextra_params,
- double* state,
- double const* sequence,
- double* loss,
- double* dloss
- );
+void rust_safe_lstm_objective(int l, int c, int b, double const *main_params,
+ double const *extra_params, double *state,
+ double const *sequence, double *loss);
+
+void cxx_restrict_lstm_objective(int l, int c, int b, double const *main_params,
+ double const *extra_params, double *state,
+ double const *sequence, double *loss);
+
+void cxx_mayalias_lstm_objective(int l, int c, int b, double const *main_params,
+ double const *extra_params, double *state,
+ double const *sequence, double *loss);
+
+void rust_safe_dlstm_objective(int l, int c, int b, double const *main_params,
+ double *dmain_params, double const *extra_params,
+ double *dextra_params, double *state,
+ double const *sequence, double *loss,
+ double *dloss);
+
+void dlstm_objective_mayalias(int l, int c, int b, double const *main_params,
+ double *dmain_params, double const *extra_params,
+ double *dextra_params, double *state,
+ double const *sequence, double *loss,
+ double *dloss);
+
+void dlstm_objective_restrict(int l, int c, int b, double const *main_params,
+ double *dmain_params, double const *extra_params,
+ double *dextra_params, double *state,
+ double const *sequence, double *loss,
+ double *dloss);
+
+void lstm_objective_b(int l, int c, int b, const double *main_params,
+ double *main_paramsb, const double *extra_params,
+ double *extra_paramsb, double *state,
+ const double *sequence, double *loss, double *lossb);
+
+void adept_dlstm_objective(int l, int c, int b, double const *main_params,
+ double *dmain_params, double const *extra_params,
+ double *dextra_params, double *state,
+ double const *sequence, double *loss, double *dloss);
}
void read_lstm_instance(const string& fn,
@@ -177,10 +196,55 @@
}
}
+// Calls cxx_mayalias_lstm_objective 100 times on `input`; returns `loss`.
+double calculate_mayalias_primal(struct LSTMInput &input) {
+  double loss = 0.0;
+  for (int remaining = 100; remaining > 0; --remaining)
+    cxx_mayalias_lstm_objective(
+        input.l, input.c, input.b, input.main_params.data(),
+        input.extra_params.data(), input.state.data(),
+        input.sequence.data(), &loss);
+  return loss;
+}
+
+// Calls cxx_restrict_lstm_objective 100 times on `input`; returns `loss`.
+double calculate_restrict_primal(struct LSTMInput &input) {
+  double loss = 0.0;
+  for (int remaining = 100; remaining > 0; --remaining)
+    cxx_restrict_lstm_objective(
+        input.l, input.c, input.b, input.main_params.data(),
+        input.extra_params.data(), input.state.data(),
+        input.sequence.data(), &loss);
+  return loss;
+}
+
+// Calls rust_unsafe_lstm_objective 100 times on `input`; returns `loss`.
+double calculate_unsafe_primal(struct LSTMInput &input) {
+  double loss = 0.0;
+  for (int remaining = 100; remaining > 0; --remaining)
+    rust_unsafe_lstm_objective(
+        input.l, input.c, input.b, input.main_params.data(),
+        input.extra_params.data(), input.state.data(),
+        input.sequence.data(), &loss);
+  return loss;
+}
+
+// Calls rust_safe_lstm_objective 100 times on `input`; returns `loss`.
+double calculate_safe_primal(struct LSTMInput &input) {
+  double loss = 0.0;
+  for (int remaining = 100; remaining > 0; --remaining)
+    rust_safe_lstm_objective(
+        input.l, input.c, input.b, input.main_params.data(),
+        input.extra_params.data(), input.state.data(),
+        input.sequence.data(), &loss);
+  return loss;
+}
+
int main(const int argc, const char* argv[]) {
printf("starting main\n");
- std::vector<std::string> paths = { "lstm_l2_c1024.txt", "lstm_l4_c1024.txt", "lstm_l2_c4096.txt", "lstm_l4_c4096.txt" };
+ //std::vector<std::string> paths = { "lstm_l2_c1024.txt", "lstm_l4_c1024.txt", "lstm_l2_c4096.txt", "lstm_l4_c4096.txt" };
+ std::vector<std::string> paths = { "lstm_l4_c4096.txt" };
std::ofstream jsonfile("results.json", std::ofstream::trunc);
json test_results;
@@ -227,16 +291,17 @@
{
- struct LSTMInput input = {};
+ struct LSTMInput input = {};
// Read instance
- read_lstm_instance("data/" + path, &input.l, &input.c, &input.b, input.main_params, input.extra_params, input.state,
- input.sequence);
+ read_lstm_instance("data/" + path, &input.l, &input.c, &input.b,
+ input.main_params, input.extra_params, input.state,
+ input.sequence);
- std::vector<double> state = std::vector<double>(input.state.size());
+ std::vector<double> state = std::vector<double>(input.state.size());
- int Jcols = 8 * input.l * input.b + 3 * input.b;
- struct LSTMOutput result = { 0, std::vector<double>(Jcols) };
+ int Jcols = 8 * input.l * input.b + 3 * input.b;
+ struct LSTMOutput result = { 0, std::vector<double>(Jcols) };
{
struct timeval start, end;
@@ -274,14 +339,81 @@
{
struct timeval start, end;
gettimeofday(&start, NULL);
- calculate_jacobian<dlstm_objective>(input, result);
+ calculate_jacobian<dlstm_objective_restrict>(input, result);
gettimeofday(&end, NULL);
- printf("Enzyme combined %0.6f\n", tdiff(&start, &end));
+ printf("Enzyme restrict combined %0.6f\n", tdiff(&start, &end));
json enzyme;
- enzyme["name"] = "Enzyme combined";
- enzyme["runtime"] = tdiff(&start, &end);
- for (unsigned i = result.gradient.size() - 5;
- i < result.gradient.size(); i++) {
+ enzyme["name"] = "Enzyme restrict combined";
+ enzyme["runtime"] = tdiff(&start, &end);
+ for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
+ i++) {
+ printf("%f ", result.gradient[i]);
+ enzyme["result"].push_back(result.gradient[i]);
+ }
+ test_suite["tools"].push_back(enzyme);
+
+ printf("\n");
+ }
+ }
+
+ {
+
+ struct LSTMInput input = {};
+
+ // Read instance
+ read_lstm_instance("data/" + path, &input.l, &input.c, &input.b,
+ input.main_params, input.extra_params, input.state,
+ input.sequence);
+
+ std::vector<double> state = std::vector<double>(input.state.size());
+
+ int Jcols = 8 * input.l * input.b + 3 * input.b;
+ struct LSTMOutput result = {0, std::vector<double>(Jcols)};
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ calculate_jacobian<dlstm_objective_mayalias>(input, result);
+ gettimeofday(&end, NULL);
+ printf("Enzyme mayalias combined %0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "Enzyme mayalias combined";
+ enzyme["runtime"] = tdiff(&start, &end);
+ for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
+ i++) {
+ printf("%f ", result.gradient[i]);
+ enzyme["result"].push_back(result.gradient[i]);
+ }
+ test_suite["tools"].push_back(enzyme);
+
+ printf("\n");
+ }
+ }
+
+ {
+
+ struct LSTMInput input = {};
+
+ // Read instance
+ read_lstm_instance("data/" + path, &input.l, &input.c, &input.b, input.main_params, input.extra_params, input.state,
+ input.sequence);
+
+ std::vector<double> state = std::vector<double>(input.state.size());
+
+ int Jcols = 8 * input.l * input.b + 3 * input.b;
+ struct LSTMOutput result = { 0, std::vector<double>(Jcols) };
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ calculate_jacobian<rust_safe_dlstm_objective>(input, result);
+ gettimeofday(&end, NULL);
+ printf("Enzyme (safe Rust) combined %0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "Enzyme (safe Rust) combined";
+ enzyme["runtime"] = tdiff(&start, &end);
+ for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
+ i++) {
printf("%f ", result.gradient[i]);
enzyme["result"].push_back(result.gradient[i]);
}
@@ -291,6 +423,161 @@
}
}
+
+ {
+
+ struct LSTMInput input = {};
+
+ // Read instance
+ read_lstm_instance("data/" + path, &input.l, &input.c, &input.b,
+ input.main_params, input.extra_params, input.state,
+ input.sequence);
+
+ std::vector<double> state = std::vector<double>(input.state.size());
+
+ int Jcols = 8 * input.l * input.b + 3 * input.b;
+ struct LSTMOutput result = {0, std::vector<double>(Jcols)};
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ calculate_jacobian<rust_unsafe_dlstm_objective>(input, result);
+ gettimeofday(&end, NULL);
+ printf("Enzyme (unsafe Rust) combined %0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "Enzyme (unsafe Rust) combined";
+ enzyme["runtime"] = tdiff(&start, &end);
+ for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
+ i++) {
+ printf("%f ", result.gradient[i]);
+ enzyme["result"].push_back(result.gradient[i]);
+ }
+ test_suite["tools"].push_back(enzyme);
+
+ printf("\n");
+ }
+ }
+ {
+
+ struct LSTMInput input = {};
+
+ // Read instance
+ read_lstm_instance("data/" + path, &input.l, &input.c, &input.b,
+ input.main_params, input.extra_params, input.state,
+ input.sequence);
+
+ std::vector<double> state = std::vector<double>(input.state.size());
+
+ int Jcols = 8 * input.l * input.b + 3 * input.b;
+ struct LSTMOutput result = {0, std::vector<double>(Jcols)};
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ double res = calculate_mayalias_primal(input);
+ gettimeofday(&end, NULL);
+ printf("C++ mayalias primal %0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "C++ mayalias primal";
+ enzyme["runtime"] = tdiff(&start, &end);
+ printf("%f ", res);
+ enzyme["result"].push_back(res);
+ test_suite["tools"].push_back(enzyme);
+
+ printf("\n");
+ }
+ }
+ {
+
+ struct LSTMInput input = {};
+
+ // Read instance
+ read_lstm_instance("data/" + path, &input.l, &input.c, &input.b,
+ input.main_params, input.extra_params, input.state,
+ input.sequence);
+
+ std::vector<double> state = std::vector<double>(input.state.size());
+
+ int Jcols = 8 * input.l * input.b + 3 * input.b;
+ struct LSTMOutput result = {0, std::vector<double>(Jcols)};
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ double res = calculate_restrict_primal(input);
+ gettimeofday(&end, NULL);
+ printf("C++ restrict primal %0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "C++ restrict primal";
+ enzyme["runtime"] = tdiff(&start, &end);
+ printf("%f ", res);
+ enzyme["result"].push_back(res);
+ test_suite["tools"].push_back(enzyme);
+
+ printf("\n");
+ }
+ }
+ {
+
+ struct LSTMInput input = {};
+
+ // Read instance
+ read_lstm_instance("data/" + path, &input.l, &input.c, &input.b,
+ input.main_params, input.extra_params, input.state,
+ input.sequence);
+
+ std::vector<double> state = std::vector<double>(input.state.size());
+
+ int Jcols = 8 * input.l * input.b + 3 * input.b;
+ struct LSTMOutput result = {0, std::vector<double>(Jcols)};
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ double res =calculate_unsafe_primal(input);
+ gettimeofday(&end, NULL);
+ printf("Enzyme (unsafe Rust) primal %0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "Enzyme (unsafe Rust) primal";
+ enzyme["runtime"] = tdiff(&start, &end);
+ printf("%f ", res);
+ enzyme["result"].push_back(res);
+ test_suite["tools"].push_back(enzyme);
+
+ printf("\n");
+ }
+ }
+ {
+
+ struct LSTMInput input = {};
+
+ // Read instance
+ read_lstm_instance("data/" + path, &input.l, &input.c, &input.b,
+ input.main_params, input.extra_params, input.state,
+ input.sequence);
+
+ std::vector<double> state = std::vector<double>(input.state.size());
+
+ int Jcols = 8 * input.l * input.b + 3 * input.b;
+ struct LSTMOutput result = {0, std::vector<double>(Jcols)};
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ double res = calculate_safe_primal(input);
+ gettimeofday(&end, NULL);
+ printf("Enzyme (safe Rust) primal %0.6f\n", tdiff(&start, &end));
+ json enzyme;
+ enzyme["name"] = "Enzyme (safe Rust) primal";
+ enzyme["runtime"] = tdiff(&start, &end);
+ printf("%f ", res);
+ enzyme["result"].push_back(res);
+ test_suite["tools"].push_back(enzyme);
+
+ printf("\n");
+ }
+ }
+
test_suite["llvm-version"] = __clang_version__;
test_suite["mode"] = "ReverseMode";
test_suite["batch-size"] = 1;
diff --git a/enzyme/benchmarks/ReverseMode/ba/Cargo.lock b/enzyme/benchmarks/ReverseMode/ba/Cargo.lock
new file mode 100644
index 0000000..74e2768
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ba/Cargo.lock
@@ -0,0 +1,16 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "bars"
+version = "0.1.0"
+dependencies = [
+ "libm",
+]
+
+[[package]]
+name = "libm"
+version = "0.2.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
diff --git a/enzyme/benchmarks/ReverseMode/ba/Cargo.toml b/enzyme/benchmarks/ReverseMode/ba/Cargo.toml
new file mode 100644
index 0000000..4bc9c21
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ba/Cargo.toml
@@ -0,0 +1,23 @@
+[package]
+name = "bars"
+version = "0.1.0"
+edition = "2021"
+
+[lib]
+crate-type = ["cdylib"]
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[profile.release]
+lto = "fat"
+opt-level = 3
+codegen-units = 1
+panic = "abort"  # Cargo's profile key is `panic`; `unwind` is not a recognized profile setting
+strip = true
+#overflow-checks = false
+
+[profile.dev]
+lto = "fat"
+
+[dependencies]
+libm = { version = "0.2.8", optional = true }
diff --git a/enzyme/benchmarks/ReverseMode/ba/Makefile.make b/enzyme/benchmarks/ReverseMode/ba/Makefile.make
index b7f013d..50ab0cf 100644
--- a/enzyme/benchmarks/ReverseMode/ba/Makefile.make
+++ b/enzyme/benchmarks/ReverseMode/ba/Makefile.make
@@ -6,6 +6,10 @@
clean:
rm -f *.ll *.o results.txt results.json
+ cargo +enzyme clean
+
+$(dir)/benchmarks/ReverseMode/ba/target/release/libbars.a: src/lib.rs src/safe.rs src/unsafe.rs Cargo.toml # rebuild when any crate module changes
+	RUSTFLAGS="-Z autodiff=Enable" cargo +enzyme rustc --release --lib --crate-type=staticlib --features=libm
%-unopt.ll: %.cpp
clang++ $(BENCH) $(PTR) $^ -pthread -O2 -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm
@@ -16,8 +20,8 @@
%-opt.ll: %-raw.ll
opt $^ -o $@ -S
-ba.o: ba-opt.ll
+ba.o: ba-opt.ll $(dir)/benchmarks/ReverseMode/ba/target/release/libbars.a
clang++ $(BENCH) -pthread -O2 $^ -I /usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -O2 -o $@ $(BENCHLINK) -lpthread -lm -L /usr/lib/gcc/x86_64-linux-gnu/11
results.json: ba.o
- ./$^
+ numactl -C 1 ./$^
diff --git a/enzyme/benchmarks/ReverseMode/ba/ba.cpp b/enzyme/benchmarks/ReverseMode/ba/ba.cpp
index b71e05a..602af73 100644
--- a/enzyme/benchmarks/ReverseMode/ba/ba.cpp
+++ b/enzyme/benchmarks/ReverseMode/ba/ba.cpp
@@ -43,17 +43,13 @@
return res;
}
-
-
-void cross(double const* a, double const* b, double* out)
-{
+void cross_restrict(double const *__restrict a, double const *__restrict b,
+ double *__restrict out) {
out[0] = a[1] * b[2] - a[2] * b[1];
out[1] = a[2] * b[0] - a[0] * b[2];
out[2] = a[0] * b[1] - a[1] * b[0];
}
-
-
/* ===================================================================== */
/* MAIN LOGIC */
/* ===================================================================== */
@@ -68,8 +64,9 @@
// n = w / theta;
// n_x = au_cross_matrix(n);
// R = eye(3) + n_x*sin(theta) + n_x*n_x*(1 - cos(theta));
-void rodrigues_rotate_point(double const* __restrict rot, double const* __restrict pt, double *__restrict rotatedPt)
-{
+void rodrigues_rotate_point_restrict(double const *__restrict rot,
+ double const *__restrict pt,
+ double *__restrict rotatedPt) {
int i;
double sqtheta = sqsum(3, rot);
if (sqtheta != 0)
@@ -87,7 +84,7 @@
w[i] = rot[i] * theta_inverse;
}
- cross(w, pt, w_cross_pt);
+ cross_restrict(w, pt, w_cross_pt);
tmp = (w[0] * pt[0] + w[1] * pt[1] + w[2] * pt[2]) *
(1. - costheta);
@@ -100,7 +97,7 @@
else
{
double rot_cross_pt[3];
- cross(rot, pt, rot_cross_pt);
+ cross_restrict(rot, pt, rot_cross_pt);
for (i = 0; i < 3; i++)
{
@@ -109,8 +106,6 @@
}
}
-
-
void radial_distort(double const* rad_params, double *proj)
{
double rsq, L;
@@ -120,10 +115,8 @@
proj[1] = proj[1] * L;
}
-
-
-void project(double const* __restrict cam, double const* __restrict X, double* __restrict proj)
-{
+void project_restrict(double const *__restrict cam, double const *__restrict X,
+ double *__restrict proj) {
double const* C = &cam[3];
double Xo[3], Xcam[3];
@@ -131,7 +124,7 @@
Xo[1] = X[1] - C[1];
Xo[2] = X[2] - C[2];
- rodrigues_rotate_point(&cam[0], Xo, Xcam);
+ rodrigues_rotate_point_restrict(&cam[0], Xo, Xcam);
proj[0] = Xcam[0] / Xcam[2];
proj[1] = Xcam[1] / Xcam[2];
@@ -142,8 +135,6 @@
proj[1] = proj[1] * cam[6] + cam[8];
}
-
-
// cam: 11 camera in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2]
// r1, r2, r3 are angle - axis rotation parameters(Rodrigues)
// [C1 C2 C3]' is the camera center
@@ -158,30 +149,23 @@
// distorted = radial_distort(projective2euclidean(Xcam), radial_parameters)
// proj = distorted * f + principal_point
// err = sqsum(proj - measurement)
-void compute_reproj_error(
- double const* __restrict cam,
- double const* __restrict X,
- double const* __restrict w,
- double const* __restrict feat,
- double * __restrict err
-)
-{
+void compute_reproj_error_restrict(double const *__restrict cam,
+ double const *__restrict X,
+ double const *__restrict w,
+ double const *__restrict feat,
+ double *__restrict err) {
double proj[2];
- project(cam, X, proj);
+ project_restrict(cam, X, proj);
err[0] = (*w)*(proj[0] - feat[0]);
err[1] = (*w)*(proj[1] - feat[1]);
}
-
-
-void compute_zach_weight_error(double const* w, double* err)
-{
+void compute_zach_weight_error_restrict(double const *__restrict w,
+ double *__restrict err) {
*err = 1 - (*w)*(*w);
}
-
-
// n number of cameras
// m number of points
// p number of observations
@@ -196,36 +180,23 @@
// feats: 2*p features (x,y coordinates corresponding to observations)
// reproj_err: 2*p errors of observations
// w_err: p weight "error" terms
-void ba_objective(
- int n,
- int m,
- int p,
- double const* cams,
- double const* X,
- double const* w,
- int const* obs,
- double const* feats,
- double* reproj_err,
- double* w_err
-)
-{
+void ba_objective_restrict(int n, int m, int p, double const *cams,
+ double const *X, double const *w, int const *obs,
+ double const *feats, double *reproj_err,
+ double *w_err) {
int i;
for (i = 0; i < p; i++)
{
int camIdx = obs[i * 2 + 0];
int ptIdx = obs[i * 2 + 1];
- compute_reproj_error(
- &cams[camIdx * BA_NCAMPARAMS],
- &X[ptIdx * 3],
- &w[i],
- &feats[i * 2],
- &reproj_err[2 * i]
- );
+ compute_reproj_error_restrict(&cams[camIdx * BA_NCAMPARAMS],
+ &X[ptIdx * 3], &w[i], &feats[i * 2],
+ &reproj_err[2 * i]);
}
for (i = 0; i < p; i++)
{
- compute_zach_weight_error(&w[i], &w_err[i]);
+ compute_zach_weight_error_restrict(&w[i], &w_err[i]);
}
}
@@ -234,32 +205,21 @@
extern int enzyme_dupnoneed;
void __enzyme_autodiff(...) noexcept;
-void dcompute_reproj_error(
- double const* cam,
- double * dcam,
- double const* X,
- double * dX,
- double const* w,
- double * wb,
- double const* feat,
- double *err,
- double *derr
-)
-{
- __enzyme_autodiff(compute_reproj_error,
- enzyme_dup, cam, dcam,
- enzyme_dup, X, dX,
- enzyme_dup, w, wb,
- enzyme_const, feat,
- enzyme_dupnoneed, err, derr);
+void dcompute_reproj_error_restrict(double const *cam, double *dcam,
+ double const *X, double *dX,
+ double const *w, double *wb,
+ double const *feat, double *err,
+ double *derr) {
+ __enzyme_autodiff(compute_reproj_error_restrict, enzyme_dup, cam, dcam,
+ enzyme_dup, X, dX, enzyme_dup, w, wb, enzyme_const, feat,
+ enzyme_dupnoneed, err, derr);
}
-void dcompute_zach_weight_error(double const* w, double* dw, double* err, double* derr) {
- __enzyme_autodiff(compute_zach_weight_error,
- enzyme_dup, w, dw,
- enzyme_dupnoneed, err, derr);
+void dcompute_zach_weight_error_restrict(double const *w, double *dw,
+ double *err, double *derr) {
+ __enzyme_autodiff(compute_zach_weight_error_restrict, enzyme_dup, w, dw,
+ enzyme_dupnoneed, err, derr);
}
-
}
@@ -911,3 +871,5 @@
*dw = aw.get_gradient();
}
+
+#include "ba_mayalias.h"
diff --git a/enzyme/benchmarks/ReverseMode/ba/ba_mayalias.h b/enzyme/benchmarks/ReverseMode/ba/ba_mayalias.h
new file mode 100644
index 0000000..25197b5
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ba/ba_mayalias.h
@@ -0,0 +1,198 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+extern "C" {
+
+/* ===================================================================== */
+/* UTILS */
+/* ===================================================================== */
+// Writes the 3-component cross product a x b into out.
+void cross(double const *a, double const *b, double *out) {
+    out[0] = a[1] * b[2] - a[2] * b[1];
+    out[1] = a[2] * b[0] - a[0] * b[2];
+    out[2] = a[0] * b[1] - a[1] * b[0];
+}
+
+/* ===================================================================== */
+/* MAIN LOGIC */
+/* ===================================================================== */
+// Stores 1 - w^2 in *err.
+void compute_zach_weight_error(double const *w, double *err) {
+    *err = 1 - (*w) * (*w);
+}
+
+// rot: 3 rotation parameters
+// pt: 3 point to be rotated
+// rotatedPt: 3 rotated point (output)
+// this is an efficient evaluation (part of
+// the Ceres implementation)
+// easy to understand calculation in matlab:
+// theta = sqrt(sum(w .^ 2));
+// n = w / theta;
+// n_x = au_cross_matrix(n);
+// R = eye(3) + n_x*sin(theta) + n_x*n_x*(1 - cos(theta));
+void rodrigues_rotate_point(double const *rot, double const *pt,
+                            double *rotatedPt) {
+    int i;
+    double sqtheta = sqsum(3, rot);
+    if (sqtheta != 0)
+    {
+        double theta, costheta, sintheta, theta_inverse;
+        double w[3], w_cross_pt[3], tmp;
+
+        theta = sqrt(sqtheta);
+        costheta = cos(theta);
+        sintheta = sin(theta);
+        theta_inverse = 1.0 / theta;
+        // w = rot / theta
+        for (i = 0; i < 3; i++)
+        {
+            w[i] = rot[i] * theta_inverse;
+        }
+
+        cross(w, pt, w_cross_pt);
+        // tmp = (w . pt) * (1 - cos(theta))
+        tmp = (w[0] * pt[0] + w[1] * pt[1] + w[2] * pt[2]) *
+            (1. - costheta);
+        // rotatedPt = pt*cos(theta) + (w x pt)*sin(theta) + w*tmp
+        for (i = 0; i < 3; i++)
+        {
+            rotatedPt[i] = pt[i] * costheta + w_cross_pt[i] * sintheta + w[i] * tmp;
+        }
+    }
+    else
+    {
+        double rot_cross_pt[3];
+        cross(rot, pt, rot_cross_pt);
+        // sqtheta == 0 branch: rotatedPt = pt + rot x pt
+        for (i = 0; i < 3; i++)
+        {
+            rotatedPt[i] = pt[i] + rot_cross_pt[i];
+        }
+    }
+}
+// Projects point X with camera parameters cam and writes the 2 image coordinates to proj.
+void project(double const *cam, double const *X, double *proj) {
+    double const* C = &cam[3];
+    double Xo[3], Xcam[3];
+    // Xo = X - C, with C taken from cam[3..5]
+    Xo[0] = X[0] - C[0];
+    Xo[1] = X[1] - C[1];
+    Xo[2] = X[2] - C[2];
+    // Xcam = Xo rotated by the parameters in cam[0..2]
+    rodrigues_rotate_point(&cam[0], Xo, Xcam);
+    // divide by the third component
+    proj[0] = Xcam[0] / Xcam[2];
+    proj[1] = Xcam[1] / Xcam[2];
+    // in-place distortion using the two parameters starting at cam[9]
+    radial_distort(&cam[9], proj);
+    // scale by cam[6], then offset by cam[7] and cam[8]
+    proj[0] = proj[0] * cam[6] + cam[7];
+    proj[1] = proj[1] * cam[6] + cam[8];
+}
+
+// cam: 11 camera in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2]
+// r1, r2, r3 are angle - axis rotation parameters(Rodrigues)
+// [C1 C2 C3]' is the camera center
+// f is the focal length in pixels
+// [u0 v0]' is the principal point
+// k1, k2 are radial distortion parameters
+// X: 3 point
+// feat: 2 feature (x,y coordinates)
+// w: 1 weight; err: 2 weighted residuals (output)
+// projection function:
+// Xcam = R * (X - C)
+// distorted = radial_distort(projective2euclidean(Xcam), radial_parameters)
+// proj = distorted * f + principal_point
+// err = w * (proj - feat), per component
+void compute_reproj_error(double const *cam, double const *X, double const *w,
+                          double const *feat, double *err) {
+    double proj[2];
+    project(cam, X, proj);
+    // both residual components are scaled by the same weight *w
+    err[0] = (*w)*(proj[0] - feat[0]);
+    err[1] = (*w)*(proj[1] - feat[1]);
+}
+
+
+
+
+// n number of cameras (not read by this function)
+// m number of points (not read by this function)
+// p number of observations
+// cams: 11*n cameras in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2]
+// r1, r2, r3 are angle - axis rotation parameters(Rodrigues)
+// [C1 C2 C3]' is the camera center
+// f is the focal length in pixels
+// [u0 v0]' is the principal point
+// k1, k2 are radial distortion parameters
+// X: 3*m points
+// obs: 2*p observations (pairs cameraIdx, pointIdx)
+// feats: 2*p features (x,y coordinates corresponding to observations)
+// reproj_err: 2*p errors of observations
+// w_err: p weight "error" terms
+void ba_objective(
+    int n,
+    int m,
+    int p,
+    double const* cams,
+    double const* X,
+    double const* w,
+    int const* obs,
+    double const* feats,
+    double* reproj_err,
+    double* w_err
+)
+{
+    int i;
+    for (i = 0; i < p; i++)
+    {
+        int camIdx = obs[i * 2 + 0];
+        int ptIdx = obs[i * 2 + 1];
+        compute_reproj_error(
+            &cams[camIdx * BA_NCAMPARAMS],
+            &X[ptIdx * 3],
+            &w[i],
+            &feats[i * 2],
+            &reproj_err[2 * i]
+        );
+    }
+    // second pass: one weight term per observation, from w[i] into w_err[i]
+    for (i = 0; i < p; i++)
+    {
+        compute_zach_weight_error(&w[i], &w_err[i]);
+    }
+}
+
+extern int enzyme_const;
+extern int enzyme_dup;
+extern int enzyme_dupnoneed;
+void __enzyme_autodiff(...) noexcept;
+// Calls __enzyme_autodiff on compute_reproj_error: cam/X/w are passed with dcam/dX/wb as enzyme_dup, feat as enzyme_const, err/derr as enzyme_dupnoneed.
+void dcompute_reproj_error(
+    double const* cam,
+    double * dcam,
+    double const* X,
+    double * dX,
+    double const* w,
+    double * wb,
+    double const* feat,
+    double *err,
+    double *derr
+)
+{
+    __enzyme_autodiff(compute_reproj_error,
+        enzyme_dup, cam, dcam,
+        enzyme_dup, X, dX,
+        enzyme_dup, w, wb,
+        enzyme_const, feat,
+        enzyme_dupnoneed, err, derr);
+}
+// Calls __enzyme_autodiff on compute_zach_weight_error: w/dw passed as enzyme_dup, err/derr as enzyme_dupnoneed.
+void dcompute_zach_weight_error(double const* w, double* dw, double* err, double* derr) {
+    __enzyme_autodiff(compute_zach_weight_error,
+        enzyme_dup, w, dw,
+        enzyme_dupnoneed, err, derr);
+}
+
+}
diff --git a/enzyme/benchmarks/ReverseMode/ba/src/lib.rs b/enzyme/benchmarks/ReverseMode/ba/src/lib.rs
new file mode 100644
index 0000000..7efd43f
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ba/src/lib.rs
@@ -0,0 +1,25 @@
+#![feature(autodiff)]
+#![allow(non_snake_case)]
+
+use std::autodiff::autodiff;
+pub mod safe;
+pub mod r#unsafe;
+
+static BA_NCAMPARAMS: usize = 11;
+// Unmangled extern "C" entry point that forwards its four pointers to dcompute_zach_weight_error.
+#[no_mangle]
+pub extern "C" fn rust_dcompute_zach_weight_error(
+    w: *const f64,
+    dw: *mut f64,
+    err: *mut f64,
+    derr: *mut f64,
+) {
+    dcompute_zach_weight_error(w, dw, err, derr);
+}
+// Stores 1 - w*w in *err; the autodiff attribute names the derivative dcompute_zach_weight_error.
+#[autodiff(dcompute_zach_weight_error, Reverse, Duplicated, Duplicated)]
+pub fn compute_zach_weight_error(w: *const f64, err: *mut f64) {
+    let w = unsafe { *w }; // pointers are dereferenced without any validity check
+    unsafe { *err = 1. - w * w; }
+}
+
diff --git a/enzyme/benchmarks/ReverseMode/ba/src/main.rs b/enzyme/benchmarks/ReverseMode/ba/src/main.rs
new file mode 100644
index 0000000..13f221b
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ba/src/main.rs
@@ -0,0 +1,26 @@
+use bars::{dcompute_zach_weight_error, safe::dcompute_reproj_error}; // dcompute_reproj_error is generated in the `safe` module, not the crate root
+fn main() {
+    let cam = [0.0; 11];
+    let mut dcam = [0.0; 11];
+    let x = [0.0; 3];
+    let mut dx = [0.0; 3];
+    let w = [0.0; 1];
+    let mut dw = [0.0; 1];
+    let feat = [0.0; 2];
+    let mut err = [0.0; 2];
+    let mut derr = [0.0; 2];
+    dcompute_reproj_error(
+        &cam as *const [f64;11],
+        &mut dcam as *mut [f64;11],
+        &x as *const [f64;3],
+        &mut dx as *mut [f64;3],
+        &w as *const [f64;1],
+        &mut dw as *mut [f64;1],
+        &feat as *const [f64;2],
+        &mut err as *mut [f64;2],
+        &mut derr as *mut [f64;2],
+    );
+
+    // Reuses the buffers above, viewed as plain f64 pointers, for the weight-error derivative.
+    dcompute_zach_weight_error(&w as *const f64, &mut dw as *mut f64, &mut err as *mut f64, &mut derr as *mut f64);
+}
diff --git a/enzyme/benchmarks/ReverseMode/ba/src/safe.rs b/enzyme/benchmarks/ReverseMode/ba/src/safe.rs
new file mode 100644
index 0000000..3530c79
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ba/src/safe.rs
@@ -0,0 +1,204 @@
+use crate::BA_NCAMPARAMS;
+use crate::compute_zach_weight_error;
+use std::autodiff::autodiff;
+// Returns the sum of the squares of the elements of `x`.
+fn sqsum(x: &[f64]) -> f64 {
+    x.iter().map(|&v| v * v).sum()
+}
+// Returns the cross product a x b of two 3-element arrays.
+#[inline]
+fn cross(a: &[f64; 3], b: &[f64; 3]) -> [f64; 3] {
+    [
+        a[1] * b[2] - a[2] * b[1],
+        a[2] * b[0] - a[0] * b[2],
+        a[0] * b[1] - a[1] * b[0],
+    ]
+}
+// Scales proj[0] and proj[1] in place by 1 + rad_params[0]*rsq + rad_params[1]*rsq^2, with rsq = sqsum(proj).
+fn radial_distort(rad_params: &[f64], proj: &mut [f64]) {
+    let rsq = sqsum(proj);
+    let l = 1. + rad_params[0] * rsq + rad_params[1] * rsq * rsq;
+    proj[0] = proj[0] * l;
+    proj[1] = proj[1] * l;
+}
+// Rotates `pt` by the 3 rotation parameters `rot` and writes the result to `rotated_pt`.
+fn rodrigues_rotate_point(rot: &[f64; 3], pt: &[f64; 3], rotated_pt: &mut [f64; 3]) {
+    let sqtheta = sqsum(rot);
+    if sqtheta != 0. {
+        let theta = sqtheta.sqrt();
+        let costheta = theta.cos();
+        let sintheta = theta.sin();
+        let theta_inverse = 1. / theta;
+        let mut w = [0.; 3]; // w = rot / theta, filled below
+        for i in 0..3 {
+            w[i] = rot[i] * theta_inverse;
+        }
+        let w_cross_pt = cross(&w, &pt);
+        let tmp = (w[0] * pt[0] + w[1] * pt[1] + w[2] * pt[2]) * (1. - costheta); // (w . pt) * (1 - cos)
+        for i in 0..3 {
+            rotated_pt[i] = pt[i] * costheta + w_cross_pt[i] * sintheta + w[i] * tmp;
+        }
+    } else {
+        let rot_cross_pt = cross(&rot, &pt); // sqtheta == 0 branch: result is pt + rot x pt
+        for i in 0..3 {
+            rotated_pt[i] = pt[i] + rot_cross_pt[i];
+        }
+    }
+}
+// Projects point `X` with the 11 camera parameters `cam` and writes the 2 image coordinates to `proj`.
+fn project(cam: &[f64; 11], X: &[f64; 3], proj: &mut [f64; 2]) {
+    let C = &cam[3..6];
+    let mut Xo = [0.; 3];
+    let mut Xcam = [0.; 3];
+    // Xo = X - C, with C = cam[3..6]
+    Xo[0] = X[0] - C[0];
+    Xo[1] = X[1] - C[1];
+    Xo[2] = X[2] - C[2];
+    // Xcam = Xo rotated by the first three entries of cam
+    rodrigues_rotate_point(cam.first_chunk::<3>().unwrap(), &Xo, &mut Xcam);
+    // divide by the third component
+    proj[0] = Xcam[0] / Xcam[2];
+    proj[1] = Xcam[1] / Xcam[2];
+    // in-place distortion using cam[9..]
+    radial_distort(&cam[9..], proj);
+    // scale by cam[6], then offset by cam[7] and cam[8]
+    proj[0] = proj[0] * cam[6] + cam[7];
+    proj[1] = proj[1] * cam[6] + cam[8];
+}
+// Unmangled extern "C" entry point that forwards its nine pointers to dcompute_reproj_error.
+#[no_mangle]
+pub extern "C" fn rust_dcompute_reproj_error(
+    cam: *const [f64; 11],
+    dcam: *mut [f64; 11],
+    x: *const [f64; 3],
+    dx: *mut [f64; 3],
+    w: *const [f64; 1],
+    wb: *mut [f64; 1],
+    feat: *const [f64; 2],
+    err: *mut [f64; 2],
+    derr: *mut [f64; 2],
+) {
+    unsafe {dcompute_reproj_error(cam, dcam, x, dx, w, wb, feat, err, derr)};
+}
+// Writes err[k] = w * (proj[k] - feat[k]) for k = 0, 1, where proj = project(cam, x); the attribute names the derivative dcompute_reproj_error.
+#[autodiff(
+    dcompute_reproj_error,
+    Reverse,
+    Duplicated,
+    Duplicated,
+    Duplicated,
+    Const,
+    DuplicatedOnly
+)]
+pub fn compute_reproj_error(
+    cam: *const [f64; 11],
+    x: *const [f64; 3],
+    w: *const [f64; 1],
+    feat: *const [f64; 2],
+    err: *mut [f64; 2],
+) {
+    let cam = unsafe { &*cam }; // raw pointers are dereferenced without any validity check
+    let w = unsafe { *(*w).get_unchecked(0) };
+    let x = unsafe { &*x };
+    let feat = unsafe { &*feat };
+    let err = unsafe { &mut *err };
+    let mut proj = [0.; 2];
+    project(cam, x, &mut proj);
+    err[0] = w * (proj[0] - feat[0]);
+    err[1] = w * (proj[1] - feat[1]);
+}
+
+// n number of cameras
+// m number of points
+// p number of observations
+// cams: 11*n cameras in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2]
+// r1, r2, r3 are angle - axis rotation parameters(Rodrigues)
+// [C1 C2 C3]' is the camera center
+// f is the focal length in pixels, [u0 v0]' is the principal point
+// k1, k2 are radial distortion parameters
+// X: 3*m points
+// obs: 2*p observations (pairs cameraIdx, pointIdx)
+// feats: 2*p features (x,y coordinates corresponding to observations)
+// reproj_err: 2*p errors of observations
+// w_err: p weight "error" terms
+// Panics if a slice length disagrees with n/m/p or an obs index is out of range.
+fn rust_ba_objective(
+    n: usize,
+    m: usize,
+    p: usize,
+    cams: &[f64],
+    x: &[f64],
+    w: &[f64],
+    obs: &[i32],
+    feats: &[f64],
+    reproj_err: &mut [f64],
+    w_err: &mut [f64],
+) {
+    assert_eq!(cams.len(), n * BA_NCAMPARAMS);
+    assert_eq!(x.len(), m * 3);
+    assert_eq!(w.len(), p);
+    assert_eq!(obs.len(), p * 2);
+    assert_eq!(feats.len(), p * 2);
+    assert_eq!(reproj_err.len(), p * 2);
+    assert_eq!(w_err.len(), p);
+
+    for i in 0..p {
+        let cam_idx = obs[i * 2 + 0] as usize;
+        let pt_idx = obs[i * 2 + 1] as usize;
+        let start = cam_idx * BA_NCAMPARAMS;
+        // Bounds-check both ends of the camera range. An open-ended tail
+        // `cams[start..]` is a valid empty slice when cam_idx == n, so taking
+        // 11 unchecked elements from it would read past the end of `cams`.
+        let cam: &[f64; 11] = cams[start..start + BA_NCAMPARAMS]
+            .try_into()
+            .unwrap();
+        // Same for the point index, which also comes from `obs`:
+        // pt_idx must be < m or the slice below panics.
+        let pt_start = pt_idx * 3;
+        let x: &[f64; 3] = x[pt_start..pt_start + 3]
+            .try_into()
+            .unwrap();
+        let w: &[f64; 1] = unsafe { w[i..].get_unchecked(..1).try_into().unwrap_unchecked() };
+        let feat: &[f64; 2] = unsafe {
+            feats[i * 2..]
+                .get_unchecked(..2)
+                .try_into()
+                .unwrap_unchecked()
+        };
+        let reproj_err: &mut [f64; 2] = unsafe {
+            reproj_err[i * 2..]
+                .get_unchecked_mut(..2)
+                .try_into()
+                .unwrap_unchecked()
+        };
+        compute_reproj_error(cam, x, w, feat, reproj_err);
+    }
+    // i < p == w.len() == w_err.len() (asserted above), so these accesses are in range.
+    for i in 0..p {
+        let w_err: &mut f64 = unsafe { w_err.get_unchecked_mut(i) };
+        compute_zach_weight_error(w[i..].as_ptr(), w_err as *mut f64);
+    }
+}
+// C entry point: wraps the raw pointers in slices and calls rust_ba_objective. NOTE(review): the C++ header declares the sibling rust2_unsafe_ba_objective with `int n, m, p`; confirm this symbol's C declaration matches `usize`.
+#[no_mangle]
+extern "C" fn rust2_ba_objective(
+    n: usize,
+    m: usize,
+    p: usize,
+    cams: *const f64,
+    x: *const f64,
+    w: *const f64,
+    obs: *const i32,
+    feats: *const f64,
+    reproj_err: *mut f64,
+    w_err: *mut f64,
+) {
+    let cams = unsafe { std::slice::from_raw_parts(cams, n * 11) }; // 11 values per camera
+    let x = unsafe { std::slice::from_raw_parts(x, m * 3) };
+    let w = unsafe { std::slice::from_raw_parts(w, p) };
+    let obs = unsafe { std::slice::from_raw_parts(obs, p * 2) };
+    let feats = unsafe { std::slice::from_raw_parts(feats, p * 2) };
+    let reproj_err = unsafe { std::slice::from_raw_parts_mut(reproj_err, p * 2) };
+    let w_err = unsafe { std::slice::from_raw_parts_mut(w_err, p) };
+    rust_ba_objective(n, m, p, cams, x, w, obs, feats, reproj_err, w_err);
+}
diff --git a/enzyme/benchmarks/ReverseMode/ba/src/unsafe.rs b/enzyme/benchmarks/ReverseMode/ba/src/unsafe.rs
new file mode 100644
index 0000000..09f74be
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ba/src/unsafe.rs
@@ -0,0 +1,140 @@
+use crate::BA_NCAMPARAMS;
+use crate::compute_zach_weight_error;
+use std::autodiff::autodiff;
+
+unsafe fn sqsum(x: *const f64, n: usize) -> f64 {
+    let mut acc = 0.; // running sum of squares of x[0..n]
+    for k in 0..n {
+        let elem = x.add(k).read(); // caller guarantees x[0..n] is readable
+        acc += elem * elem;
+    }
+    acc
+}
+
+#[inline]
+unsafe fn cross(a: *const f64, b: *const f64, out: *mut f64) {
+    out.add(0).write(*a.add(1) * *b.add(2) - *a.add(2) * *b.add(1)); // out[0] of a x b
+    out.add(1).write(*a.add(2) * *b.add(0) - *a.add(0) * *b.add(2)); // out[1] of a x b
+    out.add(2).write(*a.add(0) * *b.add(1) - *a.add(1) * *b.add(0)); // out[2] of a x b
+}
+
+unsafe fn radial_distort(rad_params: *const f64, proj: *mut f64) {
+    let r2 = sqsum(proj, 2); // proj[0]^2 + proj[1]^2
+    let scale = 1. + *rad_params.add(0) * r2 + *rad_params.add(1) * r2 * r2;
+    *proj.add(0) *= scale; // both coordinates are scaled in place by the same factor
+    *proj.add(1) *= scale;
+}
+
+unsafe fn rodrigues_rotate_point(rot: *const f64, pt: *const f64, rotated_pt: *mut f64) {
+    let theta_sq = sqsum(rot, 3); // squared norm of the 3-element rotation vector
+    if theta_sq == 0. {
+        // Zero rotation vector: the result is pt + rot x pt.
+        let mut rxp = [0.; 3];
+        cross(rot, pt, rxp.as_mut_ptr());
+        for k in 0..3 {
+            *rotated_pt.add(k) = *pt.add(k) + rxp[k];
+        }
+        return;
+    }
+    let angle = theta_sq.sqrt();
+    let cos_a = angle.cos();
+    let sin_a = angle.sin();
+    let inv_angle = 1. / angle;
+    // rot scaled to unit length.
+    let axis = [*rot.add(0) * inv_angle, *rot.add(1) * inv_angle, *rot.add(2) * inv_angle];
+    let mut axp = [0.; 3];
+    cross(axis.as_ptr(), pt, axp.as_mut_ptr());
+    let dot_term =
+        (axis[0] * *pt.add(0) + axis[1] * *pt.add(1) + axis[2] * *pt.add(2)) * (1. - cos_a);
+    for k in 0..3 {
+        *rotated_pt.add(k) = *pt.add(k) * cos_a + axp[k] * sin_a + axis[k] * dot_term;
+    }
+}
+
+unsafe fn project(cam: *const f64, X: *const f64, proj: *mut f64) {
+    // cam is read as: [0..3) rotation vector, [3..6) offset subtracted from X,
+    // [6] scale, [7] and [8] additive offsets, [9..11) radial-distortion params.
+    let cam_c = cam.add(3);
+    let shifted = [
+        *X.add(0) - *cam_c.add(0),
+        *X.add(1) - *cam_c.add(1),
+        *X.add(2) - *cam_c.add(2),
+    ];
+    let mut rotated = [0.; 3];
+    rodrigues_rotate_point(cam, shifted.as_ptr(), rotated.as_mut_ptr());
+
+    proj.add(0).write(rotated[0] / rotated[2]); // divide by the third coordinate
+    proj.add(1).write(rotated[1] / rotated[2]);
+    radial_distort(cam.add(9), proj);
+    *proj.add(0) = *proj.add(0) * *cam.add(6) + *cam.add(7);
+    *proj.add(1) = *proj.add(1) * *cam.add(6) + *cam.add(8);
+}
+
+#[no_mangle]
+pub unsafe extern "C" fn rust_unsafe_dcompute_reproj_error( // C-callable wrapper around the autodiff-generated dcompute_reproj_error
+ cam: *const f64,
+ dcam: *mut f64, // shadow (derivative) buffer paired with cam
+ x: *const f64,
+ dx: *mut f64, // shadow buffer paired with x
+ w: *const f64,
+ wb: *mut f64, // shadow buffer paired with w
+ feat: *const f64, // no shadow: feat is passed alone
+ err: *mut f64,
+ derr: *mut f64, // shadow buffer paired with err
+) {
+ unsafe {dcompute_reproj_error(cam, dcam, x, dx, w, wb, feat, err, derr)}; // arguments are forwarded unchanged, in the same order
+}
+
+
+#[autodiff( // asks the compiler to generate a derivative of compute_reproj_error
+ dcompute_reproj_error, // name of the generated function
+ Reverse,
+ Duplicated, // cam: the generated function takes (cam, dcam)
+ Duplicated, // x: the generated function takes (x, dx)
+ Duplicated, // w: the generated function takes (w, wb)
+ Const, // feat: passed alone, no shadow argument
+ DuplicatedOnly // err: the generated function takes (err, derr)
+)]
+pub unsafe fn compute_reproj_error( // writes err[0..2] = w * (project(cam, x) - feat)
+ cam: *const f64, // project() reads cam[0..11)
+ x: *const f64, // project() reads x[0..3)
+ w: *const f64, // a single value, read through *w
+ feat: *const f64, // feat[0..2) is read
+ err: *mut f64, // err[0..2) is written
+) {
+ let mut proj = [0.; 2];
+ project(cam, x, proj.as_mut_ptr());
+ *err.add(0) = *w * (proj[0] - *feat.add(0)); // weighted difference, first coordinate
+ *err.add(1) = *w * (proj[1] - *feat.add(1)); // weighted difference, second coordinate
+}
+
+#[no_mangle]
+unsafe extern "C" fn rust2_unsafe_ba_objective(
+    n: usize, // unused here. NOTE(review): ba.h declares n, m, p as `int` while this takes usize - confirm the ABI matches
+    m: usize, // unused here
+    p: usize, // number of observations to process
+    cams: *const f64,
+    x: *const f64,
+    w: *const f64,
+    obs: *const i32,
+    feats: *const f64,
+    reproj_err: *mut f64,
+    w_err: *mut f64,
+) {
+    // Pass 1: compute_reproj_error for every observation; obs holds (camera, point) index pairs.
+    for k in 0..p {
+        let cam_idx = obs.add(2 * k).read() as usize;
+        let pt_idx = obs.add(2 * k + 1).read() as usize;
+        compute_reproj_error(
+            cams.add(cam_idx * BA_NCAMPARAMS),
+            x.add(pt_idx * 3),
+            w.add(k),
+            feats.add(2 * k),
+            reproj_err.add(2 * k),
+        );
+    }
+    // Pass 2: compute_zach_weight_error for every observation, writing w_err[k].
+    for k in 0..p {
+        compute_zach_weight_error(w.add(k), w_err.add(k));
+    }
+}
diff --git a/enzyme/benchmarks/ReverseMode/fft/Cargo.lock b/enzyme/benchmarks/ReverseMode/fft/Cargo.lock
new file mode 100644
index 0000000..44847ec
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/fft/Cargo.lock
@@ -0,0 +1,7 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "fft"
+version = "0.1.0"
diff --git a/enzyme/benchmarks/ReverseMode/fft/Cargo.toml b/enzyme/benchmarks/ReverseMode/fft/Cargo.toml
new file mode 100644
index 0000000..cf8862d
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/fft/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "fft"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+
+[lib]
+crate-type = ["lib"]
+
+[profile.release]
+lto = "fat"
+opt-level = 3
+codegen-units = 1
+panic = "abort" # Cargo's profile key for this is `panic`; `unwind` is not a profile key and is ignored with an "unused manifest key" warning
+strip = true
+#overflow-checks = false
+
+[profile.dev]
+lto = "fat"
diff --git a/enzyme/benchmarks/ReverseMode/fft/Makefile.make b/enzyme/benchmarks/ReverseMode/fft/Makefile.make
index 17ea03a..b9385cd 100644
--- a/enzyme/benchmarks/ReverseMode/fft/Makefile.make
+++ b/enzyme/benchmarks/ReverseMode/fft/Makefile.make
@@ -7,6 +7,9 @@
clean:
rm -f *.ll *.o results.txt results.json
+$(dir)/benchmarks/ReverseMode/fft/target/release/libfft.a: src/lib.rs Cargo.toml
+ RUSTFLAGS="-Z autodiff=Enable" cargo +enzyme rustc --release --lib --crate-type=staticlib
+
%-unopt.ll: %.cpp
clang++ $(BENCH) $(PTR) $^ -pthread -O2 -fno-use-cxa-atexit -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm
@@ -16,7 +19,7 @@
%-opt.ll: %-raw.ll
opt $^ -o $@ -S
-fft.o: fft-opt.ll
+fft.o: fft-opt.ll $(dir)/benchmarks/ReverseMode/fft/target/release/libfft.a
clang++ $(BENCH) -pthread -O2 $^ -o $@ $(BENCHLINK) -lpthread -lm -L /usr/lib/gcc/x86_64-linux-gnu/11
#clang++ $(LOAD) $(BENCH) fft.cpp -I /usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -O2 -o fft.o -lpthread $(BENCHLINK) -lm -L /usr/lib/gcc/x86_64-linux-gnu/11
diff --git a/enzyme/benchmarks/ReverseMode/fft/fft.cpp b/enzyme/benchmarks/ReverseMode/fft/fft.cpp
index cf9459b..799b7b1 100644
--- a/enzyme/benchmarks/ReverseMode/fft/fft.cpp
+++ b/enzyme/benchmarks/ReverseMode/fft/fft.cpp
@@ -1,237 +1,368 @@
+#include <adept.h>
+#include <adept_source.h>
+#include <inttypes.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
-#include <sys/time.h>
-#include <stdlib.h>
-#include <math.h>
-#include <inttypes.h>
#include <string.h>
-#include <adept_source.h>
-#include <adept.h>
+#include <sys/time.h>
using adept::adouble;
-template<typename Return, typename... T>
-Return __enzyme_autodiff(T...);
+template <typename Return, typename... T> Return __enzyme_autodiff(T...);
float tdiff(struct timeval *start, struct timeval *end) {
- return (end->tv_sec-start->tv_sec) + 1e-6*(end->tv_usec-start->tv_usec);
+ return (end->tv_sec - start->tv_sec) + 1e-6 * (end->tv_usec - start->tv_usec);
}
#include "fft.h"
-void foobar(double* data, unsigned len) {
+void foobar(double *data, size_t len) {
fft(data, len);
ifft(data, len);
}
-void afoobar(aVector& data, unsigned len) {
+void afoobar(aVector &data, size_t len) {
fft(data, len);
ifft(data, len);
}
extern "C" {
- int enzyme_dupnoneed;
+int enzyme_dupnoneed;
}
-static double foobar_and_gradient(unsigned len) {
- double *inp = new double[2*len];
- for(int i=0; i<2*len; i++) inp[i] = 2.0;
- double *dinp = new double[2*len];
- for(int i=0; i<2*len; i++) dinp[i] = 1.0;
- __enzyme_autodiff<void>(foobar, enzyme_dupnoneed, inp, dinp, len);
- double res = dinp[0];
- delete[] dinp;
- delete[] inp;
- return res;
+extern "C" void rust_unsafe_dfoobar(size_t n, double *data, double *ddata);
+extern "C" void rust_unsafe_foobar(size_t n, double *data);
+extern "C" void rust_dfoobar(size_t n, double *data, double *ddata);
+extern "C" void rust_foobar(size_t n, double *data);
+
+static double rust_unsafe_foobar_and_gradient(size_t len) { // runs rust_unsafe_dfoobar on a fresh input and returns dinp[0]
+ double *inp = new double[2 * len]; // 2 * len doubles, all set to 2.0
+ for (size_t i = 0; i < 2 * len; i++)
+ inp[i] = 2.0;
+ double *dinp = new double[2 * len]; // derivative buffer, seeded with 1.0 everywhere
+ for (size_t i = 0; i < 2 * len; i++)
+ dinp[i] = 1.0;
+ rust_unsafe_dfoobar(len, inp, dinp); // extern "C" function declared above
+ double res = dinp[0]; // only the first entry is reported
+ delete[] dinp;
+ delete[] inp;
+ return res;
}
-static double afoobar_and_gradient(unsigned len) {
- adept::Stack stack;
-
- aVector x(2*len);
- for(int i=0; i<2*len; i++) x(i) = 2.0;
- stack.new_recording();
- afoobar(x, len);
- for(int i=0; i<2*len; i++)
- x(i).set_gradient(1.0);
- stack.compute_adjoint();
-
- double *dinp = new double[2*len];
- for(int i=0; i<2*len; i++)
- dinp[i] = x(i).get_gradient();
- double res = dinp[0];
- delete[] dinp;
- return res;
+static double rust_foobar_and_gradient(size_t len) { // runs rust_dfoobar on a fresh input and returns dinp[0]
+ double *inp = new double[2 * len]; // 2 * len doubles, all set to 2.0
+ for (size_t i = 0; i < 2 * len; i++)
+ inp[i] = 2.0;
+ double *dinp = new double[2 * len]; // derivative buffer, seeded with 1.0 everywhere
+ for (size_t i = 0; i < 2 * len; i++)
+ dinp[i] = 1.0;
+ rust_dfoobar(len, inp, dinp); // extern "C" function declared above
+ double res = dinp[0]; // only the first entry is reported
+ delete[] dinp;
+ delete[] inp;
+ return res;
}
-
-static double tfoobar_and_gradient(unsigned len) {
- double *inp = new double[2*len];
- for(int i=0; i<2*len; i++) inp[i] = 2.0;
- double *dinp = new double[2*len];
- for(int i=0; i<2*len; i++) dinp[i] = 1.0;
- foobar_b(inp, dinp, len);
- double res = dinp[0];
- delete[] dinp;
- delete[] inp;
- return res;
+__attribute__((noinline)) static double foobar_and_gradient(size_t len) {
+ double *inp = new double[2 * len];
+ for (size_t i = 0; i < 2 * len; i++)
+ inp[i] = 2.0;
+ double *dinp = new double[2 * len];
+ for (size_t i = 0; i < 2 * len; i++)
+ dinp[i] = 1.0;
+ __enzyme_autodiff<void>(foobar, enzyme_dupnoneed, inp, dinp, len);
+ double res = dinp[0];
+ delete[] dinp;
+ delete[] inp;
+ return res;
}
-static void adept_sincos(double inp, unsigned len) {
- {
- struct timeval start, end;
- gettimeofday(&start, NULL);
-
- double *x = new double[2*len];
- for(int i=0; i<2*len; i++) x[i] = 2.0;
- foobar(x, len);
- double res = x[0];
-
- gettimeofday(&end, NULL);
- printf("Adept real %0.6f res=%f\n", tdiff(&start, &end), res);
- delete[] x;
- }
-
- {
- struct timeval start, end;
- gettimeofday(&start, NULL);
-
+static double afoobar_and_gradient(size_t len) {
adept::Stack stack;
- aVector x(2*len);
- for(int i=0; i<2*len; i++) x[i] = 2.0;
- // stack.new_recording();
+ aVector x(2 * len);
+ for (size_t i = 0; i < 2 * len; i++)
+ x(i) = 2.0;
+ stack.new_recording();
afoobar(x, len);
- double res = x(0).value();
+ for (size_t i = 0; i < 2 * len; i++)
+ x(i).set_gradient(1.0);
+ stack.compute_adjoint();
- gettimeofday(&end, NULL);
- printf("Adept forward %0.6f res=%f\n", tdiff(&start, &end), res);
+ double *dinp = new double[2 * len];
+ for (size_t i = 0; i < 2 * len; i++)
+ dinp[i] = x(i).get_gradient();
+ double res = dinp[0];
+ delete[] dinp;
+ return res;
+}
+
+static double tfoobar_and_gradient(size_t len) {
+ double *inp = new double[2 * len];
+ for (size_t i = 0; i < 2 * len; i++)
+ inp[i] = 2.0;
+ double *dinp = new double[2 * len];
+ for (size_t i = 0; i < 2 * len; i++)
+ dinp[i] = 1.0;
+ foobar_b(inp, dinp, len);
+ double res = dinp[0];
+ delete[] dinp;
+ delete[] inp;
+ return res;
+}
+
+static void adept_sincos(double inp, size_t len) {
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+
+ double *x = new double[2 * len];
+ for (size_t i = 0; i < 2 * len; i++)
+ x[i] = 2.0;
+ foobar(x, len);
+ double res = x[0];
+
+ gettimeofday(&end, NULL);
+ printf("Adept real %0.6f res=%f\n", tdiff(&start, &end), res);
+ delete[] x;
}
{
- struct timeval start, end;
- gettimeofday(&start, NULL);
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
- double res2 = afoobar_and_gradient(len);
+ adept::Stack stack;
- gettimeofday(&end, NULL);
- printf("Adept combined %0.6f res'=%f\n", tdiff(&start, &end), res2);
+ aVector x(2 * len);
+ for (size_t i = 0; i < 2 * len; i++)
+ x[i] = 2.0;
+ // stack.new_recording();
+ afoobar(x, len);
+ double res = x(0).value();
+
+ gettimeofday(&end, NULL);
+ printf("Adept forward %0.6f res=%f\n", tdiff(&start, &end), res);
+ }
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+
+ double res2 = afoobar_and_gradient(len);
+
+ gettimeofday(&end, NULL);
+ printf("Adept combined %0.6f res'=%f\n", tdiff(&start, &end), res2);
}
}
+static void tapenade_sincos(double inp, size_t len) {
-static void tapenade_sincos(double inp, unsigned len) {
+ // {
+ // struct timeval start, end;
+ // gettimeofday(&start, NULL);
+
+ // double *x = new double[2*len];
+ // for(size_t i=0; i<2*len; i++) x[i] = 2.0;
+ // foobar(x, len);
+ // double res = x[0];
+
+ // gettimeofday(&end, NULL);
+ // printf("Tapenade real %0.6f res=%f\n", tdiff(&start, &end), res);
+ // delete[] x;
+ // }
+
+ // {
+ // struct timeval start, end;
+ // gettimeofday(&start, NULL);
+
+ // double* x = new double[2*len];
+ // for(size_t i=0; i<2*len; i++) x[i] = 2.0;
+ // foobar(x, len);
+ // double res = x[0];
+
+ // gettimeofday(&end, NULL);
+ // printf("Tapenade forward %0.6f res=%f\n", tdiff(&start, &end), res);
+ // delete[] x;
+ // }
+
+ // {
+ // struct timeval start, end;
+ // gettimeofday(&start, NULL);
+
+ // double res2 = tfoobar_and_gradient(len);
+
+ // gettimeofday(&end, NULL);
+ // printf("Tapenade combined %0.6f res'=%f\n", tdiff(&start, &end), res2);
+ // }
+}
+
+static void enzyme_sincos(double inp, size_t len) {
{
- struct timeval start, end;
- gettimeofday(&start, NULL);
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
- double *x = new double[2*len];
- for(int i=0; i<2*len; i++) x[i] = 2.0;
- foobar(x, len);
- double res = x[0];
+ double *x = new double[2 * len];
+ for (size_t i = 0; i < 2 * len; i++)
+ x[i] = 2.0;
+ foobar(x, len);
+ double res = x[0];
- gettimeofday(&end, NULL);
- printf("Tapenade real %0.6f res=%f\n", tdiff(&start, &end), res);
- delete[] x;
+ gettimeofday(&end, NULL);
+ printf("Enzyme real %0.6f res=%f\n", tdiff(&start, &end), res);
+ delete[] x;
}
{
- struct timeval start, end;
- gettimeofday(&start, NULL);
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
- double* x = new double[2*len];
- for(int i=0; i<2*len; i++) x[i] = 2.0;
- foobar(x, len);
- double res = x[0];
+ double *x = new double[2 * len];
+ for (size_t i = 0; i < 2 * len; i++)
+ x[i] = 2.0;
+ foobar(x, len);
+ double res = x[0];
- gettimeofday(&end, NULL);
- printf("Tapenade forward %0.6f res=%f\n", tdiff(&start, &end), res);
- delete[] x;
+ gettimeofday(&end, NULL);
+ printf("Enzyme forward %0.6f res=%f\n", tdiff(&start, &end), res);
+ delete[] x;
}
{
- struct timeval start, end;
- gettimeofday(&start, NULL);
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
- double res2 = tfoobar_and_gradient(len);
+ double res2 = foobar_and_gradient(len);
- gettimeofday(&end, NULL);
- printf("Tapenade combined %0.6f res'=%f\n", tdiff(&start, &end), res2);
+ gettimeofday(&end, NULL);
+ printf("Enzyme combined %0.6f res'=%f\n", tdiff(&start, &end), res2);
}
}
-static void enzyme_sincos(double inp, unsigned len) {
+static void enzyme_unsafe_rust_sincos(double inp, size_t len) {
{
- struct timeval start, end;
- gettimeofday(&start, NULL);
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
- double *x = new double[2*len];
- for(int i=0; i<2*len; i++) x[i] = 2.0;
- foobar(x, len);
- double res = x[0];
+ double *x = new double[2 * len];
+ for (size_t i = 0; i < 2 * len; i++)
+ x[i] = 2.0;
+ rust_unsafe_foobar(len, x);
+ double res = x[0];
- gettimeofday(&end, NULL);
- printf("Enzyme real %0.6f res=%f\n", tdiff(&start, &end), res);
- delete[] x;
+ gettimeofday(&end, NULL);
+ printf("Enzyme (unsafe Rust) real %0.6f res=%f\n", tdiff(&start, &end),
+ res);
+ delete[] x;
}
{
- struct timeval start, end;
- gettimeofday(&start, NULL);
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
- double *x = new double[2*len];
- for(int i=0; i<2*len; i++) x[i] = 2.0;
- foobar(x, len);
- double res = x[0];
+ double *x = new double[2 * len];
+ for (size_t i = 0; i < 2 * len; i++)
+ x[i] = 2.0;
+ rust_unsafe_foobar(len, x);
+ double res = x[0];
- gettimeofday(&end, NULL);
- printf("Enzyme forward %0.6f res=%f\n", tdiff(&start, &end), res);
- delete[] x;
+ gettimeofday(&end, NULL);
+ printf("Enzyme (unsafe Rust) forward %0.6f res=%f\n", tdiff(&start, &end),
+ res);
+ delete[] x;
}
{
- struct timeval start, end;
- gettimeofday(&start, NULL);
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
- double res2 = foobar_and_gradient(len);
+ double res2 = rust_unsafe_foobar_and_gradient(len);
- gettimeofday(&end, NULL);
- printf("Enzyme combined %0.6f res'=%f\n", tdiff(&start, &end), res2);
+ gettimeofday(&end, NULL);
+ printf("Enzyme (unsafe Rust) combined %0.6f res'=%f\n", tdiff(&start, &end),
+ res2);
}
}
+static void enzyme_rust_sincos(double inp, size_t len) { // times rust_foobar and rust_foobar_and_gradient; inp is not used
+
+ { // "real": time one rust_foobar call, including allocation and initialisation
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+
+ double *x = new double[2 * len];
+ for (size_t i = 0; i < 2 * len; i++)
+ x[i] = 2.0;
+ rust_foobar(len, x);
+ double res = x[0];
+
+ gettimeofday(&end, NULL);
+ printf("Enzyme (Rust) real %0.6f res=%f\n", tdiff(&start, &end), res);
+ delete[] x; // freed after the timer has stopped
+ }
+
+ { // "forward": performs the same rust_foobar call as the block above
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+
+ double *x = new double[2 * len];
+ for (size_t i = 0; i < 2 * len; i++)
+ x[i] = 2.0;
+ rust_foobar(len, x);
+ double res = x[0];
+
+ gettimeofday(&end, NULL);
+ printf("Enzyme (Rust) forward %0.6f res=%f\n", tdiff(&start, &end), res);
+ delete[] x; // freed after the timer has stopped
+ }
+
+ { // "combined": time rust_foobar_and_gradient, which allocates, runs and frees internally
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+
+ double res2 = rust_foobar_and_gradient(len);
+
+ gettimeofday(&end, NULL);
+ printf("Enzyme (Rust) combined %0.6f res'=%f\n", tdiff(&start, &end), res2);
+ }
+}
/* Function to check if x is power of 2*/
-bool isPowerOfTwo (int x)
-{
- /* First x in the below expression is for the case when x is 0 */
- return x && (!(x&(x-1)));
+bool isPowerOfTwo(size_t x) {
+ /* First x in the below expression is for the case when x is 0 */
+ return x && (!(x & (x - 1)));
}
-unsigned max(unsigned A, unsigned B){
- if (A>B) return A;
+size_t max(size_t A, size_t B) {
+ if (A > B)
+ return A;
return B;
}
-int main(int argc, char** argv) {
+int main(int argc, char **argv) {
if (argc < 2) {
printf("usage %s n [must be power of 2]\n", argv[0]);
return 1;
}
- unsigned N = atoi(argv[1]);
+ size_t N = atol(argv[1]);
if (!isPowerOfTwo(N)) {
printf("usage %s n [must be power of 2]\n", argv[0]);
return 1;
}
double inp = -2.1;
- for(unsigned iters=max(1, N>>5); iters <= N; iters*=2) {
- printf("iters=%d\n", iters);
+ size_t iters = max(1, N >> 0);
+ for (size_t i = 0; i < 5; i++) {
+ printf("iters=%zu\n", iters);
+#if CPP
adept_sincos(inp, iters);
tapenade_sincos(inp, iters);
enzyme_sincos(inp, iters);
+#else
+ enzyme_rust_sincos(inp, iters);
+ enzyme_unsafe_rust_sincos(inp, iters);
+#endif
}
}
diff --git a/enzyme/benchmarks/ReverseMode/fft/fft.h b/enzyme/benchmarks/ReverseMode/fft/fft.h
index 809196b..fad3c7d 100644
--- a/enzyme/benchmarks/ReverseMode/fft/fft.h
+++ b/enzyme/benchmarks/ReverseMode/fft/fft.h
@@ -1,71 +1,75 @@
#ifndef _fft_h_
#define _fft_h_
-#include <adept_source.h>
#include <adept.h>
#include <adept_arrays.h>
+#include <adept_source.h>
using adept::adouble;
using adept::aVector;
-
/*
A classy FFT and Inverse FFT C++ class library
Author: Tim Molteno, tim@physics.otago.ac.nz
- Based on the article "A Simple and Efficient FFT Implementation in C++" by Volodymyr Myrnyy
- with just a simple Inverse FFT modification.
+ Based on the article "A Simple and Efficient FFT Implementation in C++" by
+ Volodymyr Myrnyy with just a simple Inverse FFT modification.
Licensed under the GPL v3.
*/
-
#include <cmath>
-inline void swap(double* a, double* b) {
- double temp=*a;
+inline void swap(double *a, double *b) {
+ double temp = *a;
*a = *b;
*b = temp;
}
-static void recursiveApply(double* data, int iSign, unsigned N) {
- if (N == 1) return;
- recursiveApply(data, iSign, N/2);
- recursiveApply(data+N, iSign, N/2);
+static void recursiveApply(double *__restrict data, size_t N, int iSign) {
+ if (N == 1)
+ return;
+ recursiveApply(data, N / 2, iSign);
+ recursiveApply(data + N, N / 2, iSign);
- double wtemp = iSign*sin(M_PI/N);
- double wpi = -iSign*sin(2*M_PI/N);
- double wpr = -2.0*wtemp*wtemp;
+ double wtemp = iSign * sin(M_PI / N);
+ double wpi = -iSign * sin(2 * (M_PI / N));
+ double wpr = -2.0 * wtemp * wtemp;
double wr = 1.0;
double wi = 0.0;
- for (unsigned i=0; i<N; i+=2) {
- int iN = i+N;
+ for (size_t i = 0; i < N; i += 2) {
+ size_t iN = i + N;
+ double *__restrict ay = &data[i + 1];
+ double *__restrict ax = &data[i];
+ double *__restrict by = &data[iN + 1];
+ double *__restrict bx = &data[iN];
- double tempr = data[iN]*wr - data[iN+1]*wi;
- double tempi = data[iN]*wi + data[iN+1]*wr;
+ double tempr = *bx * wr - *by * wi;
+ double tempi = *bx * wi + *by * wr;
- data[iN] = data[i]-tempr;
- data[iN+1] = data[i+1]-tempi;
- data[i] += tempr;
- data[i+1] += tempi;
+ *bx = *ax - tempr;
+ *by = *ay - tempi;
+ *ax += tempr;
+ *ay += tempi;
wtemp = wr;
- wr += wr*wpr - wi*wpi;
- wi += wi*wpr + wtemp*wpi;
+ wr = wr * (wpr + 1.) - wi * wpi;
+ wi = wi * (wpr + 1.) + wtemp * wpi;
}
}
-static void scramble(double* data, unsigned N) {
- int j=1;
- for (int i=1; i<2*N; i+=2) {
- if (j>i) {
- swap(&data[j-1], &data[i-1]);
+static void scramble(double *data, size_t N) {
+ size_t j = 1;
+ for (size_t ii = 0; ii < N; ii++) {
+ size_t i = 2 * ii + 1;
+ if (j > i) {
+ swap(&data[j - 1], &data[i - 1]);
swap(&data[j], &data[i]);
}
- int m = N;
- while (m>=2 && j>m) {
+ size_t m = N;
+ while (m >= 2 && j > m) {
j -= m;
m >>= 1;
}
@@ -73,69 +77,71 @@
}
}
-static void rescale(double* data, unsigned N) {
- double scale = ((double)1)/N;
- for (unsigned i=0; i<2*N; i++) {
+static void rescale(double *data, size_t N) {
+ double scale = ((double)1) / N;
+ for (size_t i = 0; i < 2 * N; i++) {
data[i] *= scale;
}
}
-static void fft(double* data, unsigned N) {
+static void fft(double *data, size_t N) {
scramble(data, N);
- recursiveApply(data,1, N);
+ recursiveApply(data, N, 1);
}
-static void ifft(double* data, unsigned N) {
+static void ifft(double *data, size_t N) {
scramble(data, N);
- recursiveApply(data,-1, N);
+ recursiveApply(data, N, -1);
rescale(data, N);
}
-
-
-inline void swapad(adept::ActiveReference<double> a, adept::ActiveReference<double> b) {
- adouble temp=a;
+inline void swapad(adept::ActiveReference<double> a,
+ adept::ActiveReference<double> b) {
+ adouble temp = a;
a = b;
b = temp;
}
-static void recursiveApply(aVector data, int iSign, unsigned N) {
- if (N == 1) return;
- recursiveApply(data, iSign, N/2);
- recursiveApply(data(adept::range(N,adept::end)), iSign, N/2);
+static void recursiveApply(aVector data, size_t N, int iSign) {
+ if (N == 1)
+ return;
+ recursiveApply(data, N / 2, iSign);
+ recursiveApply(data(adept::range(N, adept::end)), N / 2, iSign);
- adouble wtemp = iSign*std::sin(M_PI/N);
- adouble wpi = -iSign*std::sin(2*M_PI/N);
- adouble wpr = -2.0*wtemp*wtemp;
+ adouble wtemp = iSign * std::sin(M_PI / N);
+ adouble wpi = -iSign * std::sin(2 * (M_PI / N));
+ adouble wpr = -2.0 * wtemp * wtemp;
adouble wr = 1.0;
adouble wi = 0.0;
- for (unsigned i=0; i<N; i+=2) {
- int iN = i+N;
+ for (size_t ii = 0; ii < N / 2; ii++) {
+ size_t i = 2 * ii;
+ size_t iN = i + N;
- adouble tempr = data(iN)*wr - data(iN+1)*wi;
- adouble tempi = data(iN)*wi + data(iN+1)*wr;
+ adouble tempr = data(iN) * wr - data(iN + 1) * wi;
+ adouble tempi = data(iN) * wi + data(iN + 1) * wr;
- data(iN) = data(i)-tempr;
- data(iN+1) = data(i+1)-tempi;
+ data(iN) = data(i) - tempr;
+ data(iN + 1) = data(i + 1) - tempi;
data(i) += tempr;
- data(i+1) += tempi;
+ data(i + 1) += tempi;
wtemp = wr;
- wr += wr*wpr - wi*wpi;
- wi += wi*wpr + wtemp*wpi;
+ wr = wr * (wpr + 1.) - wi * wpi;
+ wi = wi * (wpr + 1.) + wtemp * wpi;
}
}
-static void scramble(aVector data, unsigned N) {
- int j=1;
- for (int i=1; i<2*N; i+=2) {
- if (j>i) {
- swapad(data(j-1), data(i-1));
+static void scramble(aVector data, size_t N) {
+ size_t j = 1;
+ for (size_t ii = 0; ii < N; ii++) {
+ size_t i = 2 * ii + 1;
+ if (j > i) {
+ swapad(data(j - 1), data(i - 1));
swapad(data(j), data(i));
}
- int m = N;
- while (m>=2 && j>m) {
+ size_t m = N;
+ while (m >= 2 && j > m) {
j -= m;
m >>= 1;
}
@@ -143,21 +149,21 @@
}
}
-static void rescale(aVector data, unsigned N) {
- adouble scale = ((double)1)/N;
- for (unsigned i=0; i<2*N; i++) {
+static void rescale(aVector data, size_t N) {
+ adouble scale = ((double)1) / N;
+ for (size_t i = 0; i < 2 * N; i++) {
data[i] *= scale;
}
}
-static void fft(aVector data, unsigned N) {
+static void fft(aVector data, size_t N) {
scramble(data, N);
- recursiveApply(data,1, N);
+ recursiveApply(data, N, 1);
}
-static void ifft(aVector data, unsigned N) {
+static void ifft(aVector data, size_t N) {
scramble(data, N);
- recursiveApply(data,-1, N);
+ recursiveApply(data, N, -1);
rescale(data, N);
}
@@ -165,260 +171,308 @@
extern "C" {
/* Generated by TAPENADE (INRIA, Ecuador team)
- Tapenade 3.15 (master) - 15 Apr 2020 11:54
+ Tapenade 3.16 (bugfix_servletAD) - 4 Jan 2024 17:44
*/
#include <adBuffer.h>
+#include <adStack.h>
+#include <math.h>
/*
- Differentiation of recursiveApply in reverse (adjoint) mode (with options context):
- gradient of useful results: *data
- with respect to varying inputs: *data
- Plus diff mem management of: data:in
-*/
-static void recursiveApply_b(double *data, double *datab, int iSign, unsigned
- int N) {
- int arg1;
- double *arg10;
- double *arg10b;
- int arg2;
- if (N != 1) {
- arg1 = N/2;
- arg10b = datab + N;
- arg10 = data + N;
- arg2 = N/2;
- double wtemp = iSign*sin(3.1415926536/N);
- double wpi = -iSign*sin(2*3.1415926536/N);
- double wpr = -2.0*wtemp*wtemp;
- double wr = 1.0;
- double wi = 0.0;
- for (int i = 0; i <= N-1; i += 2) {
- int iN = i + N;
- double tempr = data[iN]*wr - data[iN+1]*wi;
- double tempi = data[iN]*wi + data[iN+1]*wr;
- double tmp;
- double tmp0;
- wtemp = wr;
- pushReal8(wr);
- wr = wr + (wr*wpr - wi*wpi);
- pushReal8(wi);
- wi = wi + (wi*wpr + wtemp*wpi);
- pushInteger4(iN);
- }
- pushPointer8(arg10b);
- pushInteger4(arg2);
- pushInteger4(arg1);
- popInteger4(&arg1);
- popInteger4(&arg2);
- popPointer8((void **)&arg10b);
- for (int i = N-(N-1)%2-1; i >= 0; i -= 2) {
- int iN;
- double tempr;
- double temprb = 0.0;
- double tempi;
- double tempib = 0.0;
- double tmpb;
- double tmpb0;
- popInteger4(&iN);
- tmpb0 = datab[iN + 1];
- popReal8(&wi);
- popReal8(&wr);
- tempib = datab[i + 1] - tmpb0;
- temprb = datab[i];
- datab[iN + 1] = 0.0;
- datab[i + 1] = datab[i + 1] + tmpb0;
- tmpb = datab[iN];
- datab[iN] = 0.0;
- datab[i] = datab[i] + tmpb;
- temprb = temprb - tmpb;
- datab[iN + 1] = datab[iN + 1] + wr*tempib - wi*temprb;
- datab[iN] = datab[iN] + wi*tempib + wr*temprb;
- }
- recursiveApply_b(arg10, arg10b, iSign, arg2);
- recursiveApply_b(data, datab, iSign, arg1);
- }
-}
-
-static void recursiveApply_nodiff(double *data, int iSign, unsigned int N) {
- int arg1;
- double *arg10;
- int arg2;
- if (N == 1)
- return;
- else {
- arg1 = N/2;
- recursiveApply_nodiff(data, iSign, arg1);
- arg10 = data + N;
- arg2 = N/2;
- recursiveApply_nodiff(arg10, iSign, arg2);
- double wtemp = iSign*sin(3.1415926536/N);
- double wpi = -iSign*sin(2*3.1415926536/N);
- double wpr = -2.0*wtemp*wtemp;
- double wr = 1.0;
- double wi = 0.0;
- for (int i = 0; i <= N-1; i += 2) {
- int iN = i + N;
- double tempr = data[iN]*wr - data[iN+1]*wi;
- double tempi = data[iN]*wi + data[iN+1]*wr;
- data[iN] = data[i] - tempr;
- data[iN + 1] = data[i + 1] - tempi;
- data[i] += tempr;
- data[i + 1] += tempi;
- wtemp = wr;
- wr += wr*wpr - wi*wpi;
- wi += wi*wpr + wtemp*wpi;
- }
- }
-}
-
-/*
- Differentiation of swap in reverse (adjoint) mode (with options context):
+ Differentiation of swap in reverse (adjoint) mode:
gradient of useful results: *a *b
with respect to varying inputs: *a *b
Plus diff mem management of: a:in b:in
*/
-static void swap_b(double *a, double *ab, double *b, double *bb) {
- double temp = *a;
- double tempb = 0.0;
- tempb = *bb;
- *bb = *ab;
- *ab = tempb;
+inline void swap_b(double *a, double *ab, double *b, double *bb) {
+ double temp = *a;
+ double tempb = 0.0;
+ *a = *b;
+ *b = temp;
+ tempb = *bb;
+ *bb = *ab;
+ *ab = tempb;
}
-static void swap_nodiff(double *a, double *b) {
- double temp = *a;
- *a = *b;
- *b = temp;
+inline void swap_c(double *a, double *b) {
+ double temp = *a;
+ *a = *b;
+ *b = temp;
}
-/*
- Differentiation of scramble in reverse (adjoint) mode (with options context):
- gradient of useful results: *data
- with respect to varying inputs: *data
- Plus diff mem management of: data:in
-*/
-static void scramble_b(double *data, double *datab, unsigned int N) {
- int j = 1;
- int branch;
- for (int i = 1; i <= 2*N-1; i += 2) {
- int adCount;
- if (j > i)
- pushControl1b(0);
- else
- pushControl1b(1);
- int m = N;
- adCount = 0;
- while(m >= 2 && j > m) {
- pushInteger4(j);
- j = j - m;
- m = m >> 1;
- adCount = adCount + 1;
- }
- pushInteger4(adCount);
- pushInteger4(j);
- j = j + m;
+static void recursiveApply_c(double *data, int iSign, size_t N) {
+ size_t arg1;
+ double *arg10;
+ size_t arg2;
+ if (N == 1)
+ return;
+ else {
+ arg1 = N / 2;
+ recursiveApply_c(data, iSign, arg1);
+ arg10 = data + N;
+ arg2 = N / 2;
+ recursiveApply_c(arg10, iSign, arg2);
+ double wtemp = iSign * sin(3.14 / N);
+ double wpi = -iSign * sin(2 * 3.14 / N);
+ double wpr = -2.0 * wtemp * wtemp;
+ double wr = 1.0;
+ double wi = 0.0;
+ for (size_t ii = 0; ii < N / 2; ii++) {
+ size_t i = 2 * ii;
+ size_t iN = i + N;
+ double tempr = data[iN] * wr - data[iN + 1] * wi;
+ double tempi = data[iN] * wi + data[iN + 1] * wr;
+ data[iN] = data[i] - tempr;
+ data[iN + 1] = data[i + 1] - tempi;
+ data[i] += tempr;
+ data[i + 1] += tempi;
+ wtemp = wr;
+ wr += wr * wpr - wi * wpi;
+ wi += wi * wpr + wtemp * wpi;
}
- for (int i = 2*N-(2*N-2)%2-1; i >= 1; i -= 2) {
- int m;
- int adCount;
- int i0;
- popInteger4(&j);
- popInteger4(&adCount);
- for (i0 = 1; i0 < adCount+1; ++i0)
- popInteger4(&j);
- popControl1b(&branch);
- if (branch == 0) {
- swap_b(&(data[j]), &(datab[j]), &(data[i]), &(datab[i]));
- swap_b(&(data[j - 1]), &(datab[j - 1]), &(data[i - 1]), &(datab[i
- - 1]));
- }
+ }
+}
+
+/*
+ Differentiation of recursiveApply in reverse (adjoint) mode:
+ gradient of useful results: *data
+ with respect to varying inputs: *data
+ Plus diff mem management of: data:in
+*/
+static void recursiveApply_b(double *data, double *datab, int iSign, size_t N) { // adjoint of recursiveApply: updates datab in place
+ size_t arg1;
+ double *arg10;
+ double *arg10b;
+ size_t arg2;
+ int branch;
+ if (N != 1) {
+ arg1 = N / 2;
+ pushReal8(*data);
+ recursiveApply_c(data, iSign, arg1); // forward sweep: recompute the first half
+ arg10b = datab + N;
+ arg10 = data + N;
+ arg2 = N / 2;
+ if (arg10) {
+ pushReal8(*arg10);
+ pushControl1b(1);
+ } else
+ pushControl1b(0);
+ recursiveApply_c(arg10, iSign, arg2); // forward sweep: recompute the second half
+ double wtemp = iSign * sin(3.14 / N); // NOTE(review): 3.14 truncates pi; the primal recursiveApply uses M_PI
+ double wpi = -iSign * sin(2 * 3.14 / N); // NOTE(review): same truncated constant
+ double wpr = -2.0 * wtemp * wtemp;
+ double wr = 1.0;
+ double wi = 0.0;
+ for (size_t ii = 0; ii < N / 2; ii++) {
+ size_t i = 2 * ii;
+ int iN = i + N; // kept as int because it is saved with pushInteger4
+ double tempr = data[iN] * wr - data[iN + 1] * wi;
+ double tempi = data[iN] * wi + data[iN + 1] * wr;
+ double temprb;
+ double tempib;
+ double tmp;
+ double tmp0;
+ tmp = data[i] - tempr;
+ data[iN] = tmp;
+ tmp0 = data[i + 1] - tempi;
+ data[iN + 1] = tmp0;
+ data[i] = data[i] + tempr;
+ data[i + 1] = data[i + 1] + tempi;
+ wtemp = wr;
+ pushReal8(wr); // saved so the reverse loop can restore wr
+ wr = wr + (wr * wpr - wi * wpi);
+ pushReal8(wi); // saved so the reverse loop can restore wi
+ wi = wi + (wi * wpr + wtemp * wpi);
+ pushInteger4(iN);
}
-}
-
-static void scramble_nodiff(double *data, unsigned int N) {
- int j = 1;
- for (int i = 1; i <= 2*N-1; i += 2) {
- if (j > i) {
- swap_nodiff(&(data[j - 1]), &(data[i - 1]));
- swap_nodiff(&(data[j]), &(data[i]));
- }
- int m = N;
- while(m >= 2 && j > m) {
- j -= m;
- m >>= 1;
- }
- j += m;
+ for (size_t i = N - (N - 1) % 2 - 1; i < N; i -= 2) { // i is unsigned: after i == 0 it wraps above N, which ends the loop (i >= 0 was always true)
+ int iN;
+ double tempr;
+ double temprb = 0.0;
+ double tempi;
+ double tempib = 0.0;
+ double tmpb;
+ double tmpb0;
+ popInteger4(&iN); // pops mirror the pushes of the forward loop, in reverse order
+ tmpb0 = datab[iN + 1];
+ popReal8(&wi);
+ popReal8(&wr);
+ tempib = datab[i + 1] - tmpb0;
+ temprb = datab[i];
+ datab[iN + 1] = 0.0;
+ datab[i + 1] = datab[i + 1] + tmpb0;
+ tmpb = datab[iN];
+ datab[iN] = 0.0;
+ datab[i] = datab[i] + tmpb;
+ temprb = temprb - tmpb;
+ datab[iN + 1] = datab[iN + 1] + wr * tempib - wi * temprb;
+ datab[iN] = datab[iN] + wi * tempib + wr * temprb;
}
+ popControl1b(&branch);
+ if (branch == 1)
+ popReal8(arg10);
+ recursiveApply_b(arg10, arg10b, iSign, arg2); // adjoint of the second-half call
+ popReal8(data);
+ recursiveApply_b(data, datab, iSign, arg1); // adjoint of the first-half call
+ }
}
/*
- Differentiation of rescale in reverse (adjoint) mode (with options context):
+ Differentiation of scramble in reverse (adjoint) mode:
gradient of useful results: *data
with respect to varying inputs: *data
Plus diff mem management of: data:in
*/
-static void rescale_b(double *data, double *datab, unsigned int N) {
- double scale = (double)1/N;
- pushReal8(scale);
- popReal8(&scale);
- for (int i = 2*N-1; i > -1; --i)
- datab[i] = scale*datab[i];
+static void scramble_b(double *data, double *datab, size_t N) { // reverse sweep: redoes the swaps while recording them, then undoes them and calls swap_b
+ int j = 1;
+ int branch;
+ for (size_t ii = 0; ii < N; ii++) { // forward loop: same traversal as scramble_c, with values and control flow pushed
+ size_t i = 2 * ii + 1;
+ int adCount;
+ if (j > i) {
+ pushReal8(data[i - 1]);
+ pushReal8(data[j - 1]);
+ swap_c(&(data[j - 1]), &(data[i - 1]));
+ pushReal8(data[i]);
+ pushReal8(data[j]);
+ swap_c(&(data[j]), &(data[i]));
+ pushControl1b(0);
+ } else
+ pushControl1b(1);
+ size_t m = N;
+ adCount = 0;
+ while (m >= 2 && j > m) {
+ pushInteger4(j);
+ j = j - m;
+ m = m >> 1;
+ adCount = adCount + 1; // number of j values pushed by this while loop
+ }
+ pushInteger4(adCount);
+ pushInteger4(j);
+ j = j + m;
+ }
+ for (size_t i = 2 * N - (2 * N - 2) % 2 - 1; i < 2 * N; i -= 2) { // i is unsigned: after i == 1 it wraps, so "i < 2 * N" ends the loop where "i >= 1" never could
+ size_t m;
+ int adCount;
+ size_t i0;
+ popInteger4(&j);
+ popInteger4(&adCount);
+ for (i0 = 1; i0 < adCount + 1; ++i0)
+ popInteger4(&j); // after adCount pops, j is the value it had at the start of this forward iteration
+ popControl1b(&branch);
+ if (branch == 0) {
+ popReal8(&(data[j]));
+ popReal8(&(data[i]));
+ swap_b(&(data[j]), &(datab[j]), &(data[i]), &(datab[i]));
+ popReal8(&(data[j - 1]));
+ popReal8(&(data[i - 1]));
+ swap_b(&(data[j - 1]), &(datab[j - 1]), &(data[i - 1]), &(datab[i - 1]));
+ }
+ }
 }
-static void rescale_nodiff(double *data, unsigned int N) {
- double scale = (double)1/N;
- for (int i = 0; i < 2*N; ++i)
- data[i] *= scale;
+static void scramble_c(double *data, size_t N) { // in-place reordering of the 2 * N doubles in data (primal version, no derivative bookkeeping)
+ size_t j = 1;
+ for (size_t ii = 0; ii < N; ii++) {
+ size_t i = 2 * ii + 1; // i visits the odd indices 1, 3, ..., 2 * N - 1
+ if (j > i) {
+ swap_c(&(data[j - 1]), &(data[i - 1])); // swaps the two-element groups ending at j and at i
+ swap_c(&(data[j]), &(data[i]));
+ }
+ size_t m = N;
+ while (m >= 2 && j > m) { // subtract N, N/2, ... from j while j exceeds the current m
+ j -= m;
+ m >>= 1;
+ }
+ j += m;
+ }
 }
/*
- Differentiation of fft in reverse (adjoint) mode (with options context):
+ Differentiation of rescale in reverse (adjoint) mode:
gradient of useful results: *data
with respect to varying inputs: *data
Plus diff mem management of: data:in
*/
-static void fft_b(double *data, double *datab, unsigned int N) {
- recursiveApply_b(data, datab, 1, N);
- scramble_b(data, datab, N);
+static void rescale_b(double *data, double *datab, size_t N) { // scales the 2 * N entries of data and of its adjoint datab by 1 / N
+ double scale = (double)1 / N;
+ for (size_t i = 0; i < 2 * N; ++i)
+ data[i] = data[i] * scale;
+ for (size_t i = 2 * N; i-- > 0;) // counts 2 * N - 1 down to 0; the old "i > -1" test was never true for an unsigned i, so datab was left unscaled
+ datab[i] = scale * datab[i];
 }
-static void fft_nodiff(double *data, unsigned int N) {
- scramble_nodiff(data, N);
- recursiveApply_nodiff(data, 1, N);
+static void rescale_c(double *data, size_t N) { // multiplies the 2 * N entries of data by 1 / N
+ double scale = (double)1 / N;
+ for (size_t i = 0; i < 2 * N; ++i)
+ data[i] *= scale;
 }
/*
- Differentiation of ifft in reverse (adjoint) mode (with options context):
+ Differentiation of fiveft in reverse (adjoint) mode:
gradient of useful results: *data
with respect to varying inputs: *data
Plus diff mem management of: data:in
*/
-static void ifft_b(double *data, double *datab, unsigned int N) {
- rescale_b(data, datab, N);
- recursiveApply_b(data, datab, -1, N);
- scramble_b(data, datab, N);
+void fiveft_b(double *data, double *datab, size_t N) { // reverse sweep for fiveft_c: forward calls first, then the _b routines in the opposite order
+ pushReal8(*data);
+ scramble_c(data, N);
+ pushReal8(*data);
+ recursiveApply_c(data, 1, N);
+ popReal8(data); // presumably restores the value saved by the most recent pushReal8 - confirm against the stack helper
+ recursiveApply_b(data, datab, 1, N);
+ popReal8(data);
+ scramble_b(data, datab, N);
 }
-static void ifft_nodiff(double *data, unsigned int N) {
- scramble_nodiff(data, N);
- recursiveApply_nodiff(data, -1, N);
- rescale_nodiff(data, N);
+void fiveft_c(double *data, size_t N) { // primal routine: scramble_c followed by recursiveApply_c with iSign = 1
+ scramble_c(data, N);
+ recursiveApply_c(data, 1, N);
 }
/*
- Differentiation of foobar in reverse (adjoint) mode (with options context):
+ Differentiation of ifiveft in reverse (adjoint) mode:
gradient of useful results: *data
with respect to varying inputs: *data
- RW status of diff variables: *data:in-out
Plus diff mem management of: data:in
*/
-void foobar_b(double *data, double *datab, unsigned int len) {
- double chksum = 0.0;
- int i;
- ifft_b(data, datab, len);
- fft_b(data, datab, len);
+void ifiveft_b(double *data, double *datab, size_t N) { // reverse sweep for ifiveft_c: forward calls first, then rescale_b, recursiveApply_b, scramble_b
+ pushReal8(*data);
+ scramble_c(data, N);
+ pushReal8(*data);
+ recursiveApply_c(data, -1, N);
+ pushReal8(*data);
+ rescale_c(data, N);
+ popReal8(data); // three pops below pair with the three pushes above
+ rescale_b(data, datab, N);
+ popReal8(data);
+ recursiveApply_b(data, datab, -1, N);
+ popReal8(data);
+ scramble_b(data, datab, N);
 }
+void ifiveft_c(double *data, size_t N) { // primal routine: scramble_c, recursiveApply_c with iSign = -1, then rescale_c
+ scramble_c(data, N);
+ recursiveApply_c(data, -1, N);
+ rescale_c(data, N);
 }
+/*
+ Differentiation of foobar in reverse (adjoint) mode:
+ gradient of useful results: *data
+ with respect to varying inputs: *data
+ RW status of diff variables: data:(loc) *data:in-out
+ Plus diff mem management of: data:in
+*/
+void foobar_b(double *data, double *datab, size_t len) { // reverse sweep for foobar: fiveft_c then ifiveft_c forward, ifiveft_b then fiveft_b backward
+ pushReal8(*data);
+ fiveft_c(data, len);
+ pushReal8(*data);
+ ifiveft_c(data, len);
+ popReal8(data);
+ ifiveft_b(data, datab, len); // the adjoint calls run in the opposite order to the forward calls
+ popReal8(data);
+ fiveft_b(data, datab, len);
+}
+}
#endif /* _fft_h_ */
diff --git a/enzyme/benchmarks/ReverseMode/fft/src/lib.rs b/enzyme/benchmarks/ReverseMode/fft/src/lib.rs
new file mode 100644
index 0000000..3b49cb6
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/fft/src/lib.rs
@@ -0,0 +1,6 @@
+#![feature(slice_swap_unchecked)]
+#![feature(autodiff)]
+#![feature(slice_as_chunks)]
+
+pub mod safe;
+pub mod unsf;
diff --git a/enzyme/benchmarks/ReverseMode/fft/src/main.rs b/enzyme/benchmarks/ReverseMode/fft/src/main.rs
new file mode 100644
index 0000000..5f76ad9
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/fft/src/main.rs
@@ -0,0 +1,22 @@
+use core::mem;
+use fft::safe;//::dfoobar;
+use fft::unsf;//::dfoobar;
+
+fn main() { // Runs the raw-pointer derivative and primal once on a small buffer and prints both vectors.
+ let len = 16;
+ let mut data = vec![1.0; 2*len]; // 2 * len values, all 1.0 ...
+ for i in 0..len {
+ data[i] = 2.0; // ... with the first len entries overwritten by 2.0
+ }
+ let mut data_d = vec![1.0; 2*len]; // passed as the third argument of unsafe_dfoobar, initialised to 1.0
+
+ //unsafe {safe::rust_dfoobar(len, data.as_mut_ptr(), data_d.as_mut_ptr());}
+ //unsafe {safe::rust_foobar(len, data.as_mut_ptr());}
+ unsafe {unsf::unsafe_dfoobar(len, data.as_mut_ptr(), data_d.as_mut_ptr());}
+ unsafe {unsf::unsafe_foobar(len, data.as_mut_ptr());}
+
+ dbg!(&data_d);
+ dbg!(&data);
+ //mem::forget(data);
+ //mem::forget(data_d);
+}
diff --git a/enzyme/benchmarks/ReverseMode/fft/src/safe.rs b/enzyme/benchmarks/ReverseMode/fft/src/safe.rs
new file mode 100644
index 0000000..cbca5ab
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/fft/src/safe.rs
@@ -0,0 +1,104 @@
+use std::autodiff::autodiff;
+use std::f64::consts::PI;
+use std::slice;
+
+/// Reorders `data` in place before the `radix2` pass.
+///
+/// `i` visits the odd indices 1, 3, 5, ...; whenever `j > i` the elements at
+/// `(j - 1, i - 1)` and `(j, i)` are exchanged with the bounds-checked `swap`.
+/// `len` is half the slice length and seeds `m` each time `j` is advanced.
+fn bitreversal_perm<T>(data: &mut [T]) {
+ let len = data.len() / 2;
+ let mut j = 1;
+
+ for i in (1..data.len()).step_by(2) {
+ if j > i {
+ data.swap(j-1, i-1);
+ data.swap(j, i);
+ }
+
+ let mut m = len;
+ while m >= 2 && j > m {
+ j -= m;
+ m >>= 1;
+ }
+
+ j += m;
+ }
+}
+
+/// In-place recursive pass: transforms each half of `data`, then combines the halves two values at a time.
+fn radix2(data: &mut [f64], i_sign: i32) {
+ let n = data.len() / 2;
+ if n <= 1 { // base case; "<=" also stops on slices shorter than two elements, which "== 1" recursed on forever
+ return;
+ }
+
+ let (a, b) = data.split_at_mut(n);
+ radix2(a, i_sign);
+ radix2(b, i_sign);
+
+ let wtemp = i_sign as f64 * (PI / n as f64).sin();
+ let wpi = -i_sign as f64 * (2.0 * (PI / n as f64)).sin();
+ let wpr = -2.0 * wtemp * wtemp;
+ let mut wr = 1.0; // (wr, wi) starts at (1, 0) and is advanced once per iteration below
+ let mut wi = 0.0;
+
+ let (achunks, _) = a.as_chunks_mut(); // chunk size 2 is inferred from the [ax, ay] patterns
+ let (bchunks, _) = b.as_chunks_mut();
+ for ([ax, ay], [bx, by]) in achunks.iter_mut().zip(bchunks.iter_mut()) {
+ let tempr = *bx * wr - *by * wi;
+ let tempi = *bx * wi + *by * wr;
+
+ *bx = *ax - tempr;
+ *by = *ay - tempi;
+ *ax += tempr;
+ *ay += tempi;
+
+ let wtemp_new = wr; // keep the old wr: the wi update must not see the new one
+ wr = wr * (wpr + 1.0) - wi * wpi;
+ wi = wi * (wpr + 1.0) + wtemp_new * wpi;
+ }
+}
+
+fn rescale(data: &mut [f64], scale: usize) { // Multiplies every element of `data` by 1 / scale.
+ let scale = 1. / scale as f64; // shadows the integer argument with its reciprocal
+ for elm in data {
+ *elm *= scale;
+ }
+}
+
+fn fft(data: &mut [f64]) { // In place: bitreversal_perm, then radix2 with i_sign = 1.
+ bitreversal_perm(data);
+ radix2(data, 1);
+}
+
+fn ifft(data: &mut [f64]) { // In place: bitreversal_perm, radix2 with i_sign = -1, then rescale.
+ bitreversal_perm(data);
+ radix2(data, -1);
+ rescale(data, data.len() / 2); // divides every element by half the slice length
+}
+
+#[autodiff(dfoobar, Reverse, DuplicatedOnly)] // names the generated reverse-mode function `dfoobar`, called from rust_dfoobar
+pub fn foobar(data: &mut [f64]) { // Applies fft and then ifft to `data` in place.
+ fft(data);
+ ifft(data);
+}
+
+#[no_mangle]
+pub extern "C" fn rust_dfoobar(n: usize, data: *mut f64, ddata: *mut f64) { // C-callable wrapper around `dfoobar`
+ let (data, ddata) = unsafe { // both pointers are viewed as slices of n * 2 f64 values; the caller must guarantee that is valid
+ (
+ slice::from_raw_parts_mut(data, n * 2),
+ slice::from_raw_parts_mut(ddata, n * 2),
+ )
+ };
+
+ unsafe { dfoobar(data, ddata) };
+}
+
+#[no_mangle]
+pub extern "C" fn rust_foobar(n: usize, data: *mut f64) { // C-callable wrapper around `foobar`
+ let data = unsafe { slice::from_raw_parts_mut(data, n * 2) }; // caller must pass a pointer valid for n * 2 f64 values
+ foobar(data);
+}
diff --git a/enzyme/benchmarks/ReverseMode/fft/src/unsf.rs b/enzyme/benchmarks/ReverseMode/fft/src/unsf.rs
new file mode 100644
index 0000000..29c8ceb
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/fft/src/unsf.rs
@@ -0,0 +1,92 @@
+use std::autodiff::autodiff;
+use std::f64::consts::PI;
+
+unsafe fn bitreversal_perm(data: *mut f64, len: usize) { // In-place reordering of the 2 * len values behind `data`; the pointer must be valid for that many.
+ let mut j = 1;
+
+ for i in (1..2 * len).step_by(2) { // i visits the odd indices 1, 3, ..., 2 * len - 1
+ if j > i {
+ std::ptr::swap(data.add(j - 1), data.add(i - 1)); // exchanges the two-element groups ending at j and at i
+ std::ptr::swap(data.add(j), data.add(i));
+ }
+
+ let mut m = len;
+ while m >= 2 && j > m { // subtract len, len/2, ... from j while j exceeds the current m
+ j -= m;
+ m >>= 1;
+ }
+
+ j += m;
+ }
+}
+
+unsafe fn radix2(data: *mut f64, n: usize, i_sign: i32) { // In-place recursive pass over the 2 * n values behind `data`.
+ if n <= 1 { // base case; "<=" also stops on n == 0, which "== 1" recursed on forever
+ return;
+ }
+ radix2(data, n / 2, i_sign); // first n values
+ radix2(data.add(n), n / 2, i_sign); // second n values
+
+ let wtemp = i_sign as f64 * (PI / n as f64).sin();
+ let wpi = -i_sign as f64 * (2.0 * (PI / n as f64)).sin();
+ let wpr = -2.0 * wtemp * wtemp;
+ let mut wr = 1.0; // (wr, wi) starts at (1, 0) and is advanced once per iteration below
+ let mut wi = 0.0;
+
+ for i in (0..n).step_by(2) {
+ let in_n = i + n; // index of the matching entry in the second half
+ let ax = &mut *data.add(i);
+ let ay = &mut *data.add(i + 1);
+ let bx = &mut *data.add(in_n);
+ let by = &mut *data.add(in_n + 1);
+ let tempr = *bx * wr - *by * wi;
+ let tempi = *bx * wi + *by * wr;
+
+ *bx = *ax - tempr;
+ *by = *ay - tempi;
+ *ax += tempr;
+ *ay += tempi;
+
+ let wtemp_new = wr; // keep the old wr: the wi update must not see the new one
+ wr = wr * (wpr + 1.0) - wi * wpi;
+ wi = wi * (wpr + 1.0) + wtemp_new * wpi;
+ }
+}
+
+unsafe fn rescale(data: *mut f64, n: usize) { // Multiplies the first 2 * n values behind `data` by 1 / n.
+ let scale = 1. / n as f64;
+ for i in 0..2 * n {
+ *data.add(i) = *data.add(i) * scale;
+ }
+}
+
+unsafe fn fft(data: *mut f64, n: usize) { // In place: bitreversal_perm, then radix2 with i_sign = 1.
+ bitreversal_perm(data, n);
+ radix2(data, n, 1);
+}
+
+unsafe fn ifft(data: *mut f64, n: usize) { // In place: bitreversal_perm, radix2 with i_sign = -1, then rescale by 1 / n.
+ bitreversal_perm(data, n);
+ radix2(data, n, -1);
+ rescale(data, n);
+}
+
+#[autodiff(unsafe_dfoobar, Reverse, Const, DuplicatedOnly)] // names the generated reverse-mode function `unsafe_dfoobar`; `n` is marked Const
+pub unsafe fn unsafe_foobar(n: usize, data: *mut f64) { // Applies fft and then ifft in place; `data` must be valid for 2 * n values.
+ fft(data, n);
+ ifft(data, n);
+}
+
+#[no_mangle]
+pub extern "C" fn rust_unsafe_dfoobar(n: usize, data: *mut f64, ddata: *mut f64) { // C-callable wrapper around `unsafe_dfoobar`
+ unsafe { // pointers are forwarded unchecked; the caller is responsible for their validity
+ unsafe_dfoobar(n, data, ddata);
+ }
+}
+
+#[no_mangle]
+pub extern "C" fn rust_unsafe_foobar(n: usize, data: *mut f64) { // C-callable wrapper around `unsafe_foobar`
+ unsafe { // pointer is forwarded unchecked; the caller is responsible for its validity
+ unsafe_foobar(n, data);
+ }
+}
diff --git a/enzyme/benchmarks/ReverseMode/gmm/Cargo.lock b/enzyme/benchmarks/ReverseMode/gmm/Cargo.lock
new file mode 100644
index 0000000..cfdab95
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/gmm/Cargo.lock
@@ -0,0 +1,16 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "gmmrs"
+version = "0.1.0"
+dependencies = [
+ "libm",
+]
+
+[[package]]
+name = "libm"
+version = "0.2.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
diff --git a/enzyme/benchmarks/ReverseMode/gmm/Cargo.toml b/enzyme/benchmarks/ReverseMode/gmm/Cargo.toml
new file mode 100644
index 0000000..1ae0273
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/gmm/Cargo.toml
@@ -0,0 +1,26 @@
+[package]
+name = "gmmrs"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[lib]
+crate-type = ["lib"]
+
+[features]
+libm = ["dep:libm"]
+
+[profile.release]
+lto = "fat"
+opt-level = 3
+codegen-units = 1
+panic = "abort"
+strip = true
+#overflow-checks = false
+
+[profile.dev]
+lto = "fat"
+
+[dependencies]
+libm = { version = "0.2.8", optional = true }
diff --git a/enzyme/benchmarks/ReverseMode/gmm/Makefile.make b/enzyme/benchmarks/ReverseMode/gmm/Makefile.make
index 1e8e711..17e22dd 100644
--- a/enzyme/benchmarks/ReverseMode/gmm/Makefile.make
+++ b/enzyme/benchmarks/ReverseMode/gmm/Makefile.make
@@ -6,6 +6,10 @@
clean:
rm -f *.ll *.o results.txt results.json
+ cargo +enzyme clean
+
+$(dir)/benchmarks/ReverseMode/gmm/target/release/libgmmrs.a: src/lib.rs Cargo.toml
+ RUSTFLAGS="-Z autodiff=Enable,LooseTypes" cargo +enzyme rustc --release --lib --crate-type=staticlib --features=libm
%-unopt.ll: %.cpp
clang++ $(BENCH) $(PTR) $^ -pthread -O2 -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm
@@ -16,9 +20,9 @@
%-opt.ll: %-raw.ll
opt $^ -o $@ -S
-gmm.o: gmm-opt.ll
+gmm.o: gmm-opt.ll $(dir)/benchmarks/ReverseMode/gmm/target/release/libgmmrs.a
clang++ -pthread -O2 $^ -o $@ $(BENCHLINK) -lm
#clang++ $(LOADCLANG) $(BENCH) gmm.cpp -I /usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -O2 -o gmm.o -lpthread $(BENCHLINK) -lm -L /usr/lib/gcc/x86_64-linux-gnu/11
results.json: gmm.o
- ./$^
+ numactl -C 1 ./$^
diff --git a/enzyme/benchmarks/ReverseMode/gmm/gmm.cpp b/enzyme/benchmarks/ReverseMode/gmm/gmm.cpp
index 8660592..cb7e864 100644
--- a/enzyme/benchmarks/ReverseMode/gmm/gmm.cpp
+++ b/enzyme/benchmarks/ReverseMode/gmm/gmm.cpp
@@ -13,7 +13,7 @@
* typedef struct
* {
* double gamma;
- * int m;
+ * size_t m;
* } Wishart;
*
* After Tapenade CLI installing use the next command to generate a file:
@@ -39,9 +39,9 @@
/* ==================================================================== */
// This throws error on n<1
-double arr_max(int n, double const* x)
+double arr_max(size_t n, double const* x)
{
- int i;
+ size_t i;
double m = x[0];
for (i = 1; i < n; i++)
{
@@ -57,9 +57,9 @@
// sum of component squares
-double sqnorm(int n, double const* x)
+double sqnorm(size_t n, double const* x)
{
- int i;
+ size_t i;
double res = x[0] * x[0];
for (i = 1; i < n; i++)
{
@@ -73,13 +73,13 @@
// out = a - b
void subtract(
- int d,
+ size_t d,
double const* x,
double const* y,
double* out
)
{
- int id;
+ size_t id;
for (id = 0; id < d; id++)
{
out[id] = x[id] - y[id];
@@ -87,9 +87,9 @@
}
-double log_sum_exp(int n, double const* x)
+double log_sum_exp(size_t n, double const* x)
{
- int i;
+ size_t i;
double mx = arr_max(n, x);
double semx = 0.0;
@@ -105,7 +105,7 @@
__attribute__((const))
double log_gamma_distrib(double a, double p)
{
- int j;
+ int64_t j;
double out = 0.25 * p * (p - 1) * log(PI);
for (j = 1; j <= p; j++)
@@ -123,17 +123,17 @@
/* ======================================================================== */
double log_wishart_prior(
- int p,
- int k,
+ size_t p,
+ size_t k,
Wishart wishart,
double const* sum_qs,
double const* Qdiags,
double const* icf
)
{
- int ik;
- int n = p + wishart.m + 1;
- int icf_sz = p * (p + 1) / 2;
+ size_t ik;
+ size_t n = p + wishart.m + 1;
+ size_t icf_sz = p * (p + 1) / 2;
double C = n * p * (log(wishart.gamma) - 0.5 * log(2)) - log_gamma_distrib(0.5 * n, p);
@@ -150,15 +150,15 @@
void preprocess_qs(
- int d,
- int k,
+ size_t d,
+ size_t k,
double const* icf,
double* sum_qs,
double* Qdiags
)
{
- int ik, id;
- int icf_sz = d * (d + 1) / 2;
+ size_t ik, id;
+ size_t icf_sz = d * (d + 1) / 2;
for (ik = 0; ik < k; ik++)
{
sum_qs[ik] = 0.;
@@ -174,14 +174,14 @@
void Qtimesx(
- int d,
+ size_t d,
double const* Qdiag,
double const* ltri, // strictly lower triangular part
double const* x,
double* out
)
{
- int i, j;
+ size_t i, j;
for (i = 0; i < d; i++)
{
out[i] = Qdiag[i] * x[i];
@@ -189,10 +189,10 @@
//caching lparams as scev doesn't replicate index calculation
// todo note changing to strengthened form
- //int Lparamsidx = 0;
+ //size_t Lparamsidx = 0;
for (i = 0; i < d; i++)
{
- int Lparamsidx = i*(2*d-i-1)/2;
+ size_t Lparamsidx = i*(2*d-i-1)/2;
for (j = i + 1; j < d; j++)
{
// and this x
@@ -202,24 +202,15 @@
}
}
-
-
-void gmm_objective(
- int d,
- int k,
- int n,
- double const* __restrict alphas,
- double const* __restrict means,
- double const* __restrict icf,
- double const* __restrict x,
- Wishart wishart,
- double* __restrict err
-)
-{
- #define int int64_t
- int ix, ik;
- const double CONSTANT = -n * d * 0.5 * log(2 * PI);
- int icf_sz = d * (d + 1) / 2;
+void gmm_objective_restrict(size_t d, size_t k, size_t n,
+ double const *__restrict alphas,
+ double const *__restrict means,
+ double const *__restrict icf,
+ double const *__restrict x, Wishart wishart,
+ double *__restrict err) {
+ int64_t ix, ik;
+ const double CONSTANT = -(double)n * d * 0.5 * log(2 * PI);
+ int64_t icf_sz = d * (d + 1) / 2;
double* Qdiags = (double*)malloc(d * k * sizeof(double));
double* sum_qs = (double*)malloc(k * sizeof(double));
@@ -256,7 +247,6 @@
free(xcentered);
free(Qxcentered);
free(main_term);
- #undef int
}
extern int enzyme_const;
@@ -265,23 +255,16 @@
void __enzyme_autodiff(...) noexcept;
// * tapenade -b -o gmm_tapenade -head "gmm_objective(err)/(alphas means icf)" gmm.c
-void dgmm_objective(int d, int k, int n, const double *alphas, double *
- alphasb, const double *means, double *meansb, const double *icf,
- double *icfb, const double *x, Wishart wishart, double *err, double *
- errb) {
- __enzyme_autodiff(
- gmm_objective,
- enzyme_const, d,
- enzyme_const, k,
- enzyme_const, n,
- enzyme_dup, alphas, alphasb,
- enzyme_dup, means, meansb,
- enzyme_dup, icf, icfb,
- enzyme_const, x,
- enzyme_const, wishart,
- enzyme_dupnoneed, err, errb);
+void dgmm_objective_restrict(size_t d, size_t k, size_t n, const double *alphas,
+ double *alphasb, const double *means,
+ double *meansb, const double *icf, double *icfb,
+ const double *x, Wishart wishart, double *err,
+ double *errb) {
+ __enzyme_autodiff(gmm_objective_restrict, enzyme_const, d, enzyme_const, k,
+ enzyme_const, n, enzyme_dup, alphas, alphasb, enzyme_dup,
+ means, meansb, enzyme_dup, icf, icfb, enzyme_const, x,
+ enzyme_const, wishart, enzyme_dupnoneed, err, errb);
}
-
}
@@ -300,20 +283,19 @@
UTILS
==================================================================== */
// This throws error on n<1
-void arr_max_b(int n, const double *x, double *xb, double arr_maxb) {
- int i;
+void arr_max_b(size_t n, const double *x, double *xb, double arr_maxb) {
double m = x[0];
double mb = 0.0;
int branch;
double arr_max;
- for (i = 1; i < n; ++i)
+ for (int64_t i = 1; i < n; ++i)
if (m < x[i]) {
m = x[i];
pushControl1b(1);
} else
pushControl1b(0);
mb = arr_maxb;
- for (i = n-1; i > 0; --i) {
+ for (int64_t i = (int64_t)n-1; i > 0; --i) {
popControl1b(&branch);
if (branch != 0) {
xb[i] = xb[i] + mb;
@@ -327,8 +309,8 @@
UTILS
==================================================================== */
// This throws error on n<1
-double arr_max_nodiff(int n, const double *x) {
- int i;
+double arr_max_nodiff(size_t n, const double *x) {
+ size_t i;
double m = x[0];
for (i = 1; i < n; ++i)
if (m < x[i])
@@ -343,20 +325,19 @@
Plus diff mem management of: x:in
*/
// sum of component squares
-void sqnorm_b(int n, const double *x, double *xb, double sqnormb) {
- int i;
+void sqnorm_b(size_t n, const double *x, double *xb, double sqnormb) {
double res = x[0]*x[0];
double resb = 0.0;
double sqnorm;
resb = sqnormb;
- for (i = n-1; i > 0; --i)
+ for (int64_t i = (int64_t)n-1; i > 0; --i)
xb[i] = xb[i] + 2*x[i]*resb;
xb[0] = xb[0] + 2*x[0]*resb;
}
// sum of component squares
-double sqnorm_nodiff(int n, const double *x) {
- int i;
+double sqnorm_nodiff(size_t n, const double *x) {
+ size_t i;
double res = x[0]*x[0];
for (i = 1; i < n; ++i)
res = res + x[i]*x[i];
@@ -370,18 +351,17 @@
Plus diff mem management of: out:in y:in
*/
// out = a - b
-void subtract_b(int d, const double *x, const double *y, double *yb, double *
+void subtract_b(size_t d, const double *x, const double *y, double *yb, double *
out, double *outb) {
- int id;
- for (id = d-1; id > -1; --id) {
+ for (int64_t id = (int64_t)d-1; id > -1; --id) {
yb[id] = yb[id] - outb[id];
outb[id] = 0.0;
}
}
// out = a - b
-void subtract_nodiff(int d, const double *x, const double *y, double *out) {
- int id;
+void subtract_nodiff(size_t d, const double *x, const double *y, double *out) {
+ size_t id;
for (id = 0; id < d; ++id)
out[id] = x[id] - y[id];
}
@@ -392,8 +372,7 @@
with respect to varying inputs: *x
Plus diff mem management of: x:in
*/
-void log_sum_exp_b(int n, const double *x, double *xb, double log_sum_expb) {
- int i;
+void log_sum_exp_b(size_t n, const double *x, double *xb, double log_sum_expb) {
double mx;
double mxb;
double tempb;
@@ -401,11 +380,11 @@
mx = arr_max_nodiff(n, x);
double semx = 0.0;
double semxb = 0.0;
- for (i = 0; i < n; ++i)
+ for (int64_t i = 0; i < n; ++i)
semx = semx + exp(x[i] - mx);
semxb = log_sum_expb/semx;
mxb = log_sum_expb;
- for (i = n-1; i > -1; --i) {
+ for (int64_t i = (int64_t)n-1; i > -1; --i) {
tempb = exp(x[i]-mx)*semxb;
xb[i] = xb[i] + tempb;
mxb = mxb - tempb;
@@ -413,8 +392,8 @@
arr_max_b(n, x, xb, mxb);
}
-double log_sum_exp_nodiff(int n, const double *x) {
- int i;
+double log_sum_exp_nodiff(size_t n, const double *x) {
+ size_t i;
double mx;
mx = arr_max_nodiff(n, x);
double semx = 0.0;
@@ -424,7 +403,7 @@
}
double log_gamma_distrib_nodiff(double a, double p) {
- int j;
+ size_t j;
/* TFIX */
double out = 0.25*p*(p-1)*log(PI);
double arg1;
@@ -446,12 +425,12 @@
========================================================================
MAIN LOGIC
======================================================================== */
-void log_wishart_prior_b(int p, int k, Wishart wishart, const double *sum_qs,
+void log_wishart_prior_b(size_t p, size_t k, Wishart wishart, const double *sum_qs,
double *sum_qsb, const double *Qdiags, double *Qdiagsb, const double *
icf, double *icfb, double log_wishart_priorb) {
- int ik;
- int n = p + wishart.m + 1;
- int icf_sz = p*(p+1)/2;
+ int64_t ik;
+ size_t n = p + wishart.m + 1;
+ size_t icf_sz = p*(p+1)/2;
double C;
float arg1;
double result1;
@@ -461,7 +440,7 @@
for (ik = 0; ik < k; ++ik) {
double frobenius;
double result1;
- int arg1;
+ size_t arg1;
double result2;
}
outb = log_wishart_priorb;
@@ -471,12 +450,12 @@
sum_qsb[ik] = 0.0;
for (ik = 0; ik < k * icf_sz; ik++) /* TFIX */
icfb[ik] = 0.0;
- for (ik = k-1; ik > -1; --ik) {
+ for (ik = (int64_t)k-1; ik > -1; --ik) {
double frobenius;
double frobeniusb;
double result1;
double result1b;
- int arg1;
+ size_t arg1;
double result2;
double result2b;
frobeniusb = wishart.gamma*wishart.gamma*0.5*outb;
@@ -493,11 +472,11 @@
/* ========================================================================
MAIN LOGIC
======================================================================== */
-double log_wishart_prior_nodiff(int p, int k, Wishart wishart, const double *
+double log_wishart_prior_nodiff(size_t p, size_t k, Wishart wishart, const double *
sum_qs, const double *Qdiags, const double *icf) {
- int ik;
- int n = p + wishart.m + 1;
- int icf_sz = p*(p+1)/2;
+ size_t ik;
+ size_t n = p + wishart.m + 1;
+ size_t icf_sz = p*(p+1)/2;
double C;
float arg1;
double result1;
@@ -508,7 +487,7 @@
for (ik = 0; ik < k; ++ik) {
double frobenius;
double result1;
- int arg1;
+ size_t arg1;
double result2;
result1 = sqnorm_nodiff(p, &(Qdiags[ik*p]));
arg1 = icf_sz - p;
@@ -526,17 +505,17 @@
with respect to varying inputs: *icf
Plus diff mem management of: Qdiags:in sum_qs:in icf:in
*/
-void preprocess_qs_b(int d, int k, const double *icf, double *icfb, double *
+void preprocess_qs_b(size_t d, size_t k, const double *icf, double *icfb, double *
sum_qs, double *sum_qsb, double *Qdiags, double *Qdiagsb) {
- int ik, id;
- int icf_sz = d*(d+1)/2;
+ int64_t ik, id;
+ size_t icf_sz = d*(d+1)/2;
for (ik = 0; ik < k; ++ik)
for (id = 0; id < d; ++id) {
double q = icf[ik*icf_sz + id];
pushReal8(q);
}
- for (ik = k-1; ik > -1; --ik) {
- for (id = d-1; id > -1; --id) {
+ for (ik = (int64_t)k-1; ik > -1; --ik) {
+ for (id = (int64_t)d-1; id > -1; --id) {
double q;
double qb = 0.0;
popReal8(&q);
@@ -549,13 +528,12 @@
}
}
-void preprocess_qs_nodiff(int d, int k, const double *icf, double *sum_qs,
+void preprocess_qs_nodiff(size_t d, size_t k, const double *icf, double *sum_qs,
double *Qdiags) {
- int ik, id;
- int icf_sz = d*(d+1)/2;
- for (ik = 0; ik < k; ++ik) {
+ size_t icf_sz = d*(d+1)/2;
+ for (size_t ik = 0; ik < k; ++ik) {
sum_qs[ik] = 0.;
- for (id = 0; id < d; ++id) {
+ for (size_t id = 0; id < d; ++id) {
double q = icf[ik*icf_sz + id];
sum_qs[ik] = sum_qs[ik] + q;
Qdiags[ik*d + id] = exp(q);
@@ -569,41 +547,41 @@
with respect to varying inputs: *out *Qdiag *x *ltri
Plus diff mem management of: out:in Qdiag:in x:in ltri:in
*/
-void Qtimesx_b(int d, const double *Qdiag, double *Qdiagb, const double *ltri,
+void Qtimesx_b(size_t d, const double *Qdiag, double *Qdiagb, const double *ltri,
double *ltrib, const double *x, double *xb, double *out, double *outb)
{
// strictly lower triangular part
- int i, j;
+ int64_t i, j;
int adFrom;
- int Lparamsidx = 0;
+ size_t Lparamsidx = 0;
for (i = 0; i < d; ++i) {
adFrom = i + 1;
for (j = adFrom; j < d; ++j)
Lparamsidx++;
pushInteger4(adFrom);
}
- for (i = d-1; i > -1; --i) {
+ for (i = (int64_t)d-1; i > -1; --i) {
popInteger4(&adFrom);
- for (j = d-1; j > adFrom-1; --j) {
+ for (j = (int64_t)d-1; j > adFrom-1; --j) {
--Lparamsidx;
ltrib[Lparamsidx] = ltrib[Lparamsidx] + x[i]*outb[j];
xb[i] = xb[i] + ltri[Lparamsidx]*outb[j];
}
}
- for (i = d-1; i > -1; --i) {
+ for (i = (int64_t)d-1; i > -1; --i) {
Qdiagb[i] = Qdiagb[i] + x[i]*outb[i];
xb[i] = xb[i] + Qdiag[i]*outb[i];
outb[i] = 0.0;
}
}
-void Qtimesx_nodiff(int d, const double *Qdiag, const double *ltri, const
+void Qtimesx_nodiff(size_t d, const double *Qdiag, const double *ltri, const
double *x, double *out) {
// strictly lower triangular part
- int i, j;
+ size_t i, j;
for (i = 0; i < d; ++i)
out[i] = Qdiag[i]*x[i];
- int Lparamsidx = 0;
+ size_t Lparamsidx = 0;
for (i = 0; i < d; ++i)
for (j = i+1; j < d; ++j) {
out[j] = out[j] + ltri[Lparamsidx]*x[i];
@@ -619,19 +597,19 @@
*alphas:out
Plus diff mem management of: err:in means:in icf:in alphas:in
*/
-void gmm_objective_b(int d, int k, int n, const double *alphas, double *
+void gmm_objective_b(size_t d, size_t k, size_t n, const double *alphas, double *
alphasb, const double *means, double *meansb, const double *icf,
double *icfb, const double *x, Wishart wishart, double *err, double *
errb) {
- int ix, ik;
+ int64_t ix, ik;
/* TFIX */
- const double CONSTANT = -n*d*0.5*log(2*PI);
- int icf_sz = d*(d+1)/2;
+ const double CONSTANT = -(double)n*d*0.5*log(2*PI);
+ size_t icf_sz = d*(d+1)/2;
double *Qdiags;
double *Qdiagsb;
double result1;
double result1b;
- int ii1;
+ size_t ii1;
Qdiagsb = (double *)malloc(d*k*sizeof(double));
for (ii1 = 0; ii1 < d*k; ++ii1)
Qdiagsb[ii1] = 0.0;
@@ -687,10 +665,10 @@
log_sum_exp_b(k, alphas, alphasb, lse_alphasb);
for (ii1 = 0; ii1 < d * k; ii1++) /* TFIX */
meansb[ii1] = 0.0;
- for (ix = n-1; ix > -1; --ix) {
+ for (ix = (int64_t)n-1; ix > -1; --ix) {
result1b = slseb;
log_sum_exp_b(k, &(main_term[0]), &(main_termb[0]), result1b);
- for (ik = k-1; ik > -1; --ik) {
+ for (ik = (int64_t)k-1; ik > -1; --ik) {
popReal8(&(main_term[ik]));
alphasb[ik] = alphasb[ik] + main_termb[ik];
sum_qsb[ik] = sum_qsb[ik] + main_termb[ik];
@@ -733,32 +711,32 @@
// out = a - b
template<typename T1, typename T2, typename T3>
-void subtract(int d,
+void subtract(size_t d,
const T1* const x,
const T2* const y,
T3* out)
{
- for (int id = 0; id < d; id++)
+ for (size_t id = 0; id < d; id++)
{
out[id] = x[id] - y[id];
}
}
template<typename T>
-T sqnorm(int n, const T* const x)
+T sqnorm(size_t n, const T* const x)
{
T res = x[0] * x[0];
- for (int i = 1; i < n; i++)
+ for (size_t i = 1; i < n; i++)
res = res + x[i] * x[i];
return res;
}
// This throws error on n<1
template<typename T>
-T arr_max(int n, const T* const x)
+T arr_max(size_t n, const T* const x)
{
T m = x[0];
- for (int i = 1; i < n; i++)
+ for (size_t i = 1; i < n; i++)
{
if (m < x[i])
m = x[i];
@@ -767,12 +745,12 @@
}
template<typename T>
-void gmm_objective(int d, int k, int n, const T* const alphas, const T* const means,
+void gmm_objective(size_t d, size_t k, size_t n, const T* const alphas, const T* const means,
const T* const icf, const double* const x, Wishart wishart, T* err);
// split of the outer loop over points
template<typename T>
-void gmm_objective_split_inner(int d, int k,
+void gmm_objective_split_inner(size_t d, size_t k,
const T* const alphas,
const T* const means,
const T* const icf,
@@ -781,7 +759,7 @@
T* err);
// other terms which are outside the loop
template<typename T>
-void gmm_objective_split_other(int d, int k, int n,
+void gmm_objective_split_other(size_t d, size_t k, size_t n,
const T* const alphas,
const T* const means,
const T* const icf,
@@ -789,7 +767,7 @@
T* err);
template<typename T>
-T logsumexp(int n, const T* const x);
+T logsumexp(size_t n, const T* const x);
// p: dim
// k: number of components
@@ -798,20 +776,20 @@
// Qdiags: d*k
// icf: (p*(p+1)/2)*k inverse covariance factors
template<typename T>
-T log_wishart_prior(int p, int k,
+T log_wishart_prior(size_t p, size_t k,
Wishart wishart,
const T* const sum_qs,
const T* const Qdiags,
const T* const icf);
template<typename T>
-void preprocess_qs(int d, int k,
+void preprocess_qs(size_t d, size_t k,
const T* const icf,
T* sum_qs,
T* Qdiags);
template<typename T>
-void Qtimesx(int d,
+void Qtimesx(size_t d,
const T* const Qdiag,
const T* const ltri, // strictly lower triangular part
const T* const x,
@@ -822,11 +800,11 @@
////////////////////////////////////////////////////////////
template<typename T>
-T logsumexp(int n, const T* const x)
+T logsumexp(size_t n, const T* const x)
{
T mx = arr_max(n, x);
T semx = 0.;
- for (int i = 0; i < n; i++)
+ for (size_t i = 0; i < n; i++)
{
semx = semx + exp(x[i] - mx);
}
@@ -834,19 +812,19 @@
}
template<typename T>
-T log_wishart_prior(int p, int k,
+T log_wishart_prior(size_t p, size_t k,
Wishart wishart,
const T* const sum_qs,
const T* const Qdiags,
const T* const icf)
{
- int n = p + wishart.m + 1;
- int icf_sz = p * (p + 1) / 2;
+ size_t n = p + wishart.m + 1;
+ size_t icf_sz = p * (p + 1) / 2;
double C = n * p * (log(wishart.gamma) - 0.5 * log(2)) - log_gamma_distrib(0.5 * n, p);
T out = 0;
- for (int ik = 0; ik < k; ik++)
+ for (size_t ik = 0; ik < k; ik++)
{
T frobenius = sqnorm(p, &Qdiags[ik * p]) + sqnorm(icf_sz - p, &icf[ik * icf_sz + p]);
out = out + 0.5 * wishart.gamma * wishart.gamma * (frobenius)
@@ -857,16 +835,16 @@
}
template<typename T>
-void preprocess_qs(int d, int k,
+void preprocess_qs(size_t d, size_t k,
const T* const icf,
T* sum_qs,
T* Qdiags)
{
- int icf_sz = d * (d + 1) / 2;
- for (int ik = 0; ik < k; ik++)
+ size_t icf_sz = d * (d + 1) / 2;
+ for (size_t ik = 0; ik < k; ik++)
{
sum_qs[ik] = 0.;
- for (int id = 0; id < d; id++)
+ for (size_t id = 0; id < d; id++)
{
T q = icf[ik * icf_sz + id];
sum_qs[ik] = sum_qs[ik] + q;
@@ -876,19 +854,19 @@
}
template<typename T>
-void Qtimesx(int d,
+void Qtimesx(size_t d,
const T* const Qdiag,
const T* const ltri, // strictly lower triangular part
const T* const x,
T* out)
{
- for (int id = 0; id < d; id++)
+ for (size_t id = 0; id < d; id++)
out[id] = Qdiag[id] * x[id];
- int Lparamsidx = 0;
- for (int i = 0; i < d; i++)
+ size_t Lparamsidx = 0;
+ for (size_t i = 0; i < d; i++)
{
- for (int j = i + 1; j < d; j++)
+ for (size_t j = i + 1; j < d; j++)
{
out[j] = out[j] + ltri[Lparamsidx] * x[i];
Lparamsidx++;
@@ -897,7 +875,7 @@
}
template<typename T>
-void gmm_objective(int d, int k, int n,
+void gmm_objective(size_t d, size_t k, size_t n,
const T* const alphas,
const T* const means,
const T* const icf,
@@ -905,8 +883,8 @@
Wishart wishart,
T* err)
{
- const double CONSTANT = -n * d * 0.5 * log(2 * PI);
- int icf_sz = d * (d + 1) / 2;
+ const double CONSTANT = -(double)n * d * 0.5 * log(2 * PI);
+ size_t icf_sz = d * (d + 1) / 2;
vector<T> Qdiags(d * k);
vector<T> sum_qs(k);
@@ -917,9 +895,9 @@
preprocess_qs(d, k, icf, &sum_qs[0], &Qdiags[0]);
T slse = 0.;
- for (int ix = 0; ix < n; ix++)
+ for (size_t ix = 0; ix < n; ix++)
{
- for (int ik = 0; ik < k; ik++)
+ for (size_t ik = 0; ik < k; ik++)
{
subtract(d, &x[ix * d], &means[ik * d], &xcentered[0]);
Qtimesx(d, &Qdiags[ik * d], &icf[ik * icf_sz + d], &xcentered[0], &Qxcentered[0]);
@@ -937,7 +915,7 @@
}
template<typename T>
-void gmm_objective_split_inner(int d, int k,
+void gmm_objective_split_inner(size_t d, size_t k,
const T* const alphas,
const T* const means,
const T* const icf,
@@ -945,39 +923,39 @@
Wishart wishart,
T* err)
{
- int icf_sz = d * (d + 1) / 2;
+ size_t icf_sz = d * (d + 1) / 2;
T* Ldiag = new T[d];
T* xcentered = new T[d];
T* mahal = new T[d];
T* lse = new T[k];
- for (int ik = 0; ik < k; ik++)
+ for (size_t ik = 0; ik < k; ik++)
{
- int icf_off = ik * icf_sz;
+ size_t icf_off = ik * icf_sz;
T sumlog_Ldiag(0.);
- for (int id = 0; id < d; id++)
+ for (size_t id = 0; id < d; id++)
{
sumlog_Ldiag = sumlog_Ldiag + icf[icf_off + id];
Ldiag[id] = exp(icf[icf_off + id]);
}
- for (int id = 0; id < d; id++)
+ for (size_t id = 0; id < d; id++)
{
xcentered[id] = x[id] - means[ik * d + id];
mahal[id] = Ldiag[id] * xcentered[id];
}
- int Lparamsidx = d;
- for (int i = 0; i < d; i++)
+ size_t Lparamsidx = d;
+ for (size_t i = 0; i < d; i++)
{
- for (int j = i + 1; j < d; j++)
+ for (size_t j = i + 1; j < d; j++)
{
mahal[j] = mahal[j] + icf[icf_off + Lparamsidx] * xcentered[i];
Lparamsidx++;
}
}
T sqsum_mahal(0.);
- for (int id = 0; id < d; id++)
+ for (size_t id = 0; id < d; id++)
{
sqsum_mahal = sqsum_mahal + mahal[id] * mahal[id];
}
@@ -994,14 +972,14 @@
}
template<typename T>
-void gmm_objective_split_other(int d, int k, int n,
+void gmm_objective_split_other(size_t d, size_t k, size_t n,
const T* const alphas,
const T* const means,
const T* const icf,
Wishart wishart,
T* err)
{
- const double CONSTANT = -n * d * 0.5 * log(2 * PI);
+ const double CONSTANT = -(double)n * d * 0.5 * log(2 * PI);
T lse_alphas = logsumexp(k, alphas);
@@ -1015,14 +993,14 @@
};
-void adept_dgmm_objective(int d, int k, int n, const double *alphas, double *
+void adept_dgmm_objective(size_t d, size_t k, size_t n, const double *alphas, double *
alphasb, const double *means, double *meansb, const double *icf,
double *icfb, const double *x, Wishart wishart, double *err, double *
errb) {
- int icf_sz = d*(d + 1) / 2;
- int Jrows = 1;
- int Jcols = (k*(d + 1)*(d + 2)) / 2;
+ size_t icf_sz = d*(d + 1) / 2;
+ size_t Jrows = 1;
+ size_t Jcols = (k*(d + 1)*(d + 2)) / 2;
adept::Stack stack;
adouble *aalphas = new adouble[k];
@@ -1050,3 +1028,5 @@
delete[] ameans;
delete[] aicf;
}
+
+#include "gmm_mayalias.h"
diff --git a/enzyme/benchmarks/ReverseMode/gmm/gmm.h b/enzyme/benchmarks/ReverseMode/gmm/gmm.h
index eb189af..5dc5bc0 100644
--- a/enzyme/benchmarks/ReverseMode/gmm/gmm.h
+++ b/enzyme/benchmarks/ReverseMode/gmm/gmm.h
@@ -28,9 +28,9 @@
// wishart: wishart distribution parameters
// err: 1 output
void gmm_objective(
- int d,
- int k,
- int n,
+ size_t d,
+ size_t k,
+ size_t n,
double const* alphas,
double const* means,
double const* icf,
diff --git a/enzyme/benchmarks/ReverseMode/gmm/gmm_mayalias.h b/enzyme/benchmarks/ReverseMode/gmm/gmm_mayalias.h
new file mode 100644
index 0000000..4bcba4f
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/gmm/gmm_mayalias.h
@@ -0,0 +1,62 @@
+void gmm_objective(size_t d, size_t k, size_t n, double const *alphas,
+ double const *means, double const *icf, double const *x,
+ Wishart wishart, double *err) { // writes the scalar objective to *err
+ size_t ix, ik;
+ const double CONSTANT = -(double)n * d * 0.5 * log(2 * PI);
+ size_t icf_sz = d * (d + 1) / 2; // stride between per-component blocks of icf
+
+ double *Qdiags = (double *)malloc(d * k * sizeof(double)); // NOTE(review): none of these malloc results is checked
+ double *sum_qs = (double *)malloc(k * sizeof(double));
+ double *xcentered = (double *)malloc(d * sizeof(double)); // scratch, overwritten for every (ix, ik) pair
+ double *Qxcentered = (double *)malloc(d * sizeof(double)); // scratch, overwritten for every (ix, ik) pair
+ double *main_term = (double *)malloc(k * sizeof(double)); // one entry per component, reused for every ix
+
+ preprocess_qs(d, k, icf, &sum_qs[0], &Qdiags[0]); // fills sum_qs and Qdiags from icf
+
+ double slse = 0.;
+ for (ix = 0; ix < n; ix++) {
+ for (ik = 0; ik < k; ik++) {
+ subtract(d, &x[ix * d], &means[ik * d], &xcentered[0]);
+ Qtimesx(d, &Qdiags[ik * d], &icf[ik * icf_sz + d], &xcentered[0],
+ &Qxcentered[0]); // the ltri argument starts d entries into the component's icf block
+ // two caches for qxcentered at idx 0 and at arbitrary index
+ main_term[ik] = alphas[ik] + sum_qs[ik] - 0.5 * sqnorm(d, &Qxcentered[0]);
+ }
+
+ // storing cmp for max of main_term
+ // 2 x (0 and arbitrary) storing sub to exp
+ // storing sum for use in log
+ slse = slse + log_sum_exp(k, &main_term[0]);
+ }
+
+ // storing cmp of alphas
+ double lse_alphas = log_sum_exp(k, alphas);
+
+ *err = CONSTANT + slse - n * lse_alphas +
+ log_wishart_prior(d, k, wishart, &sum_qs[0], &Qdiags[0], icf);
+
+ free(Qdiags); // release all scratch buffers allocated above
+ free(sum_qs);
+ free(xcentered);
+ free(Qxcentered);
+ free(main_term);
+}
+
+// Differentiates gmm_objective through __enzyme_autodiff; the signature mirrors the Tapenade head "gmm_objective(err)/(alphas means icf)" from gmm.c.
+void dgmm_objective(size_t d, size_t k, size_t n, const double *alphas, double *
+ alphasb, const double *means, double *meansb, const double *icf,
+ double *icfb, const double *x, Wishart wishart, double *err, double *
+ errb) {
+ __enzyme_autodiff(
+ gmm_objective,
+ enzyme_const, d,
+ enzyme_const, k,
+ enzyme_const, n,
+ enzyme_dup, alphas, alphasb, // each enzyme_dup argument is followed by its shadow buffer
+ enzyme_dup, means, meansb,
+ enzyme_dup, icf, icfb,
+ enzyme_const, x,
+ enzyme_const, wishart,
+ enzyme_dupnoneed, err, errb); // errb is the shadow paired with err
+}
+
diff --git a/enzyme/benchmarks/ReverseMode/gmm/src/lib.rs b/enzyme/benchmarks/ReverseMode/gmm/src/lib.rs
new file mode 100644
index 0000000..4f9fc53
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/gmm/src/lib.rs
@@ -0,0 +1,10 @@
+#![feature(autodiff)]
+pub mod safe;
+pub mod r#unsafe;
+
+#[derive(Clone, Copy)]
+#[repr(C)] // C layout: the extern "C" wrappers receive this struct through `*const Wishart`
+pub struct Wishart {
+ pub gamma: f64, // used by log_wishart_prior as ln(gamma) and 0.5 * gamma * gamma * frobenius
+ pub m: i32, // used by log_wishart_prior in n = p + m + 1 and as the factor on sum_qs
+}
diff --git a/enzyme/benchmarks/ReverseMode/gmm/src/main.rs b/enzyme/benchmarks/ReverseMode/gmm/src/main.rs
new file mode 100644
index 0000000..e7ebf74
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/gmm/src/main.rs
@@ -0,0 +1,24 @@
+#![feature(autodiff)]
+use gmmrs::{Wishart, r#unsafe::dgmm_objective};
+
+fn main() {
+    let d = 2;
+    let k = 2;
+    let n = 2;
+    let alphas = vec![0.5, 0.5];
+    let means = vec![0., 0., 1., 1.];
+    // k components x d * (d + 1) / 2 entries each = 6; a shorter icf is indexed out of bounds.
+    let icf = vec![1., 0., 1., 1., 0., 1.];
+    let x = vec![0., 0., 1., 1.];
+    let wishart = Wishart { gamma: 1., m: 1 };
+    let mut err = 0.;
+    let mut d_alphas = vec![0.; alphas.len()];
+    let mut d_means = vec![0.; means.len()];
+    let mut d_icf = vec![0.; icf.len()];
+    let mut d_err = 0.;
+    let err2 = &mut err;
+    let d_err2 = &mut d_err;
+    let wishart2 = &wishart;
+    // pass as raw ptr:
+    unsafe {dgmm_objective(d, k, n, alphas.as_ptr(), d_alphas.as_mut_ptr(), means.as_ptr(), d_means.as_mut_ptr(), icf.as_ptr(), d_icf.as_mut_ptr(), x.as_ptr(), wishart2 as *const Wishart, err2 as *mut f64, d_err2 as *mut f64);}
+}
diff --git a/enzyme/benchmarks/ReverseMode/gmm/src/safe.rs b/enzyme/benchmarks/ReverseMode/gmm/src/safe.rs
new file mode 100644
index 0000000..9356b11
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/gmm/src/safe.rs
@@ -0,0 +1,303 @@
+use crate::Wishart;
+use std::f64::consts::PI;
+use std::autodiff::autodiff;
+
+#[cfg(feature = "libm")]
+use libm::lgamma;
+
+#[cfg(not(feature = "libm"))]
+mod cmath {
+ extern "C" {
+ pub fn lgamma(x: f64) -> f64; // external symbol; presumably the C math library's lgamma
+ }
+}
+#[cfg(not(feature = "libm"))]
+#[inline]
+fn lgamma(x: f64) -> f64 { // safe wrapper compiled only when the `libm` feature is off
+ unsafe { cmath::lgamma(x) }
+}
+
+#[no_mangle]
+pub extern "C" fn rust_dgmm_objective( // C entry point: wraps the raw pointers in slices and calls the generated dgmm_objective
+ d: i32,
+ k: i32,
+ n: i32,
+ alphas: *const f64,
+ dalphas: *mut f64,
+ means: *const f64,
+ dmeans: *mut f64,
+ icf: *const f64,
+ dicf: *mut f64,
+ x: *const f64,
+ wishart: *const Wishart,
+ err: *mut f64,
+ derr: *mut f64,
+) {
+ let k = k as usize; // NOTE(review): a negative i32 would wrap to a huge usize in these casts
+ let n = n as usize;
+ let d = d as usize;
+ let alphas = unsafe { std::slice::from_raw_parts(alphas, k) };
+ let means = unsafe { std::slice::from_raw_parts(means, k * d) };
+ let icf = unsafe { std::slice::from_raw_parts(icf, k * d * (d + 1) / 2) }; // d * (d + 1) / 2 entries per component
+ let x = unsafe { std::slice::from_raw_parts(x, n * d) };
+ let wishart: Wishart = unsafe { *wishart };
+ let mut my_err = unsafe { *err }; // local copy; written back to *err at the end
+
+ let d_alphas = unsafe { std::slice::from_raw_parts_mut(dalphas, k) }; // shadow slices use the same lengths as their primals
+ let d_means = unsafe { std::slice::from_raw_parts_mut(dmeans, k * d) };
+ let d_icf = unsafe { std::slice::from_raw_parts_mut(dicf, k * d * (d + 1) / 2) };
+ let mut my_derr = unsafe { *derr }; // local copy; written back to *derr at the end
+ let (mut qdiags, mut sum_qs, mut xcentered, mut qxcentered, mut main_term) =
+ get_workspace(d, k); // primal scratch buffers, zero-initialized
+ let (mut bqdiags, mut bsum_qs, mut bxcentered, mut bqxcentered, mut bmain_term) =
+ get_workspace(d, k); // shadow scratch buffers paired with the primal ones, zero-initialized
+
+ unsafe { dgmm_objective(
+ d,
+ k,
+ n,
+ alphas,
+ d_alphas,
+ means,
+ d_means,
+ icf,
+ d_icf,
+ x,
+ wishart.gamma, // the Wishart struct is passed as its two scalar fields
+ wishart.m,
+ &mut my_err,
+ &mut my_derr,
+ &mut qdiags,
+ &mut bqdiags,
+ &mut sum_qs,
+ &mut bsum_qs,
+ &mut xcentered,
+ &mut bxcentered,
+ &mut qxcentered,
+ &mut bqxcentered,
+ &mut main_term,
+ &mut bmain_term,
+ )};
+
+ unsafe { *err = my_err }; // copy the locals back to the caller's storage
+ unsafe { *derr = my_derr };
+}
+
+#[no_mangle]
+pub extern "C" fn rust_gmm_objective( // C entry point: wraps the raw pointers in slices and calls gmm_objective
+ d: i32,
+ k: i32,
+ n: i32,
+ alphas: *const f64,
+ means: *const f64,
+ icf: *const f64,
+ x: *const f64,
+ wishart: *const Wishart,
+ err: *mut f64,
+) {
+ let k = k as usize; // NOTE(review): a negative i32 would wrap to a huge usize in these casts
+ let n = n as usize;
+ let d = d as usize;
+ let alphas = unsafe { std::slice::from_raw_parts(alphas, k) };
+ let means = unsafe { std::slice::from_raw_parts(means, k * d) };
+ let icf = unsafe { std::slice::from_raw_parts(icf, k * d * (d + 1) / 2) }; // d * (d + 1) / 2 entries per component
+ let x = unsafe { std::slice::from_raw_parts(x, n * d) };
+ let wishart: Wishart = unsafe { *wishart };
+ let mut my_err = unsafe { *err }; // local copy; written back to *err at the end
+ let (mut qdiags, mut sum_qs, mut xcentered, mut qxcentered, mut main_term) =
+ get_workspace(d, k); // scratch buffers sized to match the assertions in gmm_objective
+ gmm_objective(
+ d,
+ k,
+ n,
+ alphas,
+ means,
+ icf,
+ x,
+ wishart.gamma, // the Wishart struct is passed as its two scalar fields
+ wishart.m,
+ &mut my_err,
+ &mut qdiags,
+ &mut sum_qs,
+ &mut xcentered,
+ &mut qxcentered,
+ &mut main_term,
+ );
+ unsafe { *err = my_err }; // copy the result back to the caller's storage
+}
+
+/// Allocates zeroed scratch buffers for `gmm_objective`, returned in the order
+/// `(qdiags[d * k], sum_qs[k], xcentered[d], qxcentered[d], main_term[k])`.
+fn get_workspace(d: usize, k: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
+    let zeros = |len: usize| vec![0.0_f64; len];
+    let per_component_diag = zeros(d * k);
+    // The lengths match the `assert_eq!` checks at the top of `gmm_objective`.
+    (per_component_diag, zeros(k), zeros(d), zeros(d), zeros(k))
+}
+
+#[autodiff(
+ dgmm_objective, // name of the generated derivative function
+ Reverse,
+ Const, // d
+ Const, // k
+ Const, // n
+ Duplicated, // alphas
+ Duplicated, // means
+ Duplicated, // icf
+ Const, // x
+ Const, // gamma
+ Const, // m
+ DuplicatedOnly, // err
+ Duplicated, // qdiags
+ Duplicated, // sum_qs
+ Duplicated, // xcentered
+ Duplicated, // qxcentered
+ Duplicated // main_term
+)]
+pub fn gmm_objective(
+ d: usize,
+ k: usize,
+ n: usize,
+ alphas: &[f64],
+ means: &[f64],
+ icf: &[f64],
+ x: &[f64],
+ gamma: f64,
+ m: i32,
+ err: &mut f64,
+ qdiags: &mut [f64],
+ sum_qs: &mut [f64],
+ xcentered: &mut [f64],
+ qxcentered: &mut [f64],
+ main_term: &mut [f64],
+) {
+ let wishart: Wishart = Wishart { gamma, m }; // rebuilt from the two scalars the wrappers pass in
+ let constant = -(n as f64) * d as f64 * 0.5 * (2.0 * PI).ln();
+ let icf_sz = d * (d + 1) / 2; // stride between per-component blocks of icf
+
+ // Let the compiler know sizes so it can eliminate bounds checks
+ assert_eq!(qdiags.len(), d * k);
+ assert_eq!(sum_qs.len(), k);
+ assert_eq!(xcentered.len(), d);
+ assert_eq!(qxcentered.len(), d);
+ assert_eq!(main_term.len(), k);
+
+ preprocess_qs(d, k, icf, sum_qs, qdiags); // fills sum_qs and qdiags from icf
+
+ let mut slse = 0.;
+ for ix in 0..n {
+ for ik in 0..k {
+ subtract(
+ d,
+ &x[ix as usize * d as usize..], // the `as usize` casts in this loop are no-ops: the values are already usize
+ &means[ik as usize * d as usize..],
+ xcentered,
+ );
+ qtimesx(
+ d,
+ &qdiags[ik as usize * d as usize..],
+ &icf[ik as usize * icf_sz as usize + d as usize..], // ltri starts d entries into the component's icf block
+ &*xcentered,
+ qxcentered,
+ );
+ main_term[ik as usize] =
+ alphas[ik as usize] + sum_qs[ik as usize] - 0.5 * sqnorm(&*qxcentered);
+ }
+
+ slse = slse + log_sum_exp(k, &main_term);
+ }
+
+ let lse_alphas = log_sum_exp(k, alphas);
+
+ *err = constant + slse - n as f64 * lse_alphas
+ + log_wishart_prior(d, k, wishart, &sum_qs, &*qdiags, icf);
+}
+
+/// Largest of the first `n` entries of `x` (`f64::NEG_INFINITY` when `n == 0`).
+/// NaN entries never replace the running maximum; panics if `n > x.len()`.
+fn arr_max(n: usize, x: &[f64]) -> f64 {
+    let mut best = f64::NEG_INFINITY;
+    for &candidate in &x[..n] {
+        best = if best < candidate { candidate } else { best };
+    }
+    best
+}
+
+fn preprocess_qs(d: usize, k: usize, icf: &[f64], sum_qs: &mut [f64], qdiags: &mut [f64]) { // fills sum_qs[k] and qdiags[k * d] from icf
+ let icf_sz = d * (d + 1) / 2; // stride between per-component blocks of icf
+ for ik in 0..k {
+ sum_qs[ik as usize] = 0.;
+ for id in 0..d {
+ let q = icf[ik as usize * icf_sz as usize + id as usize]; // first d entries of component ik's block
+ sum_qs[ik as usize] = sum_qs[ik as usize] + q; // running sum of those d entries
+ qdiags[ik as usize * d as usize + id as usize] = q.exp(); // their exponentials, d per component
+ }
+ }
+}
+/// Writes `x[i] - y[i]` into `out[i]` for `i < d`; panics if any slice is shorter than `d`.
+fn subtract(d: usize, x: &[f64], y: &[f64], out: &mut [f64]) {
+    assert!(x.len() >= d);
+    assert!(y.len() >= d);
+    assert!(out.len() >= d);
+    let operands = x[..d].iter().zip(&y[..d]);
+    out[..d].iter_mut().zip(operands).for_each(|(dst, (a, b))| *dst = a - b);
+}
+
+fn qtimesx(d: usize, q_diag: &[f64], ltri: &[f64], x: &[f64], out: &mut [f64]) { // out = diagonal part plus strictly-lower-triangular part applied to x
+ assert!(out.len() >= d);
+ assert!(q_diag.len() >= d);
+ assert!(x.len() >= d);
+ for i in 0..d {
+ out[i] = q_diag[i] * x[i]; // diagonal contribution
+ }
+
+ for i in 0..d {
+ let mut lparamsidx = i * (2 * d - i - 1) / 2; // start of column i in ltri: columns 0..i hold d-1, d-2, ... entries
+ for j in i + 1..d {
+ out[j] = out[j] + ltri[lparamsidx] * x[i]; // entry (j, i) below the diagonal
+ lparamsidx += 1;
+ }
+ }
+}
+
+/// ln(sum(exp(x[i]))) over the first `n` entries only, shifted by `arr_max(n, x)`.
+fn log_sum_exp(n: usize, x: &[f64]) -> f64 {
+    let mx = arr_max(n, x);
+    x[..n].iter().map(|v| (v - mx).exp()).sum::<f64>().ln() + mx
+}
+fn log_gamma_distrib(a: f64, p: f64) -> f64 { // 0.25 * p * (p - 1) * ln(PI) plus a sum of lgamma terms
+ 0.25 * p * (p - 1.) * PI.ln()
+ + (1..=p as usize) // p is truncated to an integer count of terms
+ .map(|j| lgamma(a + 0.5 * (1. - j as f64)))
+ .sum::<f64>()
+}
+
+fn log_wishart_prior(
+ p: usize,
+ k: usize,
+ wishart: Wishart,
+ sum_qs: &[f64],
+ qdiags: &[f64],
+ icf: &[f64],
+) -> f64 {
+ let n = p + wishart.m as usize + 1; // NOTE(review): a negative m would wrap to a huge usize here
+ let icf_sz = p * (p + 1) / 2; // stride between per-component blocks of icf
+
+ let c = n as f64 * p as f64 * (wishart.gamma.ln() - 0.5 * 2f64.ln())
+ - log_gamma_distrib(0.5 * n as f64, p as f64);
+
+ let out = (0..k)
+ .map(|ik| {
+ let frobenius = sqnorm(&qdiags[ik * p as usize..][..p]) // the p qdiags entries of component ik
+ + sqnorm(&icf[ik * icf_sz as usize + p as usize..][..icf_sz - p]); // the icf_sz - p entries after the first p
+ 0.5 * wishart.gamma * wishart.gamma * (frobenius)
+ - (wishart.m as f64) * sum_qs[ik as usize]
+ })
+ .sum::<f64>();
+
+ out - k as f64 * c // c is subtracted once per component
+}
+
+fn sqnorm(x: &[f64]) -> f64 { // sum of the squares of every entry of x
+ x.iter().map(|x| x * x).sum()
+}
diff --git a/enzyme/benchmarks/ReverseMode/gmm/src/unsafe.rs b/enzyme/benchmarks/ReverseMode/gmm/src/unsafe.rs
new file mode 100644
index 0000000..aa91938
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/gmm/src/unsafe.rs
@@ -0,0 +1,148 @@
+use std::f64::consts::PI;
+use crate::Wishart;
+use std::autodiff::autodiff;
+
+#[cfg(feature = "libm")]
+use libm::lgamma;
+
+#[cfg(not(feature = "libm"))]
+mod cmath {
+ extern "C" {
+ pub fn lgamma(x: f64) -> f64; // external symbol; presumably the C math library's lgamma
+ }
+}
+#[cfg(not(feature = "libm"))]
+#[inline]
+fn lgamma(x: f64) -> f64 { // safe wrapper compiled only when the `libm` feature is off
+ unsafe { cmath::lgamma(x) }
+}
+
+#[no_mangle]
+pub extern "C" fn rust_unsafe_dgmm_objective(d: i32, k: i32, n: i32, alphas: *const f64, dalphas: *mut f64, means: *const f64, dmeans: *mut f64, icf: *const f64, dicf: *mut f64, x: *const f64, wishart: *const Wishart, err: *mut f64, derr: *mut f64) {
+ let k = k as usize; // NOTE(review): a negative i32 would wrap to a huge usize in these casts
+ let n = n as usize;
+ let d = d as usize;
+ unsafe { dgmm_objective(d, k, n, alphas, dalphas, means, dmeans, icf, dicf, x, wishart, err, derr); } // raw pointers are forwarded unchanged to the generated derivative
+}
+
+#[no_mangle]
+pub extern "C" fn rust_unsafe_gmm_objective(d: i32, k: i32, n: i32, alphas: *const f64, means: *const f64, icf: *const f64, x: *const f64, wishart: *const Wishart, err: *mut f64) {
+ let k = k as usize; // NOTE(review): a negative i32 would wrap to a huge usize in these casts
+ let n = n as usize;
+ let d = d as usize;
+ unsafe {gmm_objective(d, k, n, alphas, means, icf, x, wishart, err); } // raw pointers are forwarded unchanged
+}
+
+//#[autodiff(dgmm_objective, Reverse, Const, Const, Const, Duplicated, Duplicated, Duplicated, Const, Const, Duplicated)]
+//pub unsafe fn gmm_objective(d: usize, k: usize, n: usize, alphas: &[f64], means: &[f64], icf: &[f64], x: &[f64], gamma: f64, m: i32, err: &mut f64) {
+// gmm_objective(d, k, n, alphas, means, icf, x, wishart, &mut my_err);
+//}
+
+#[autodiff(dgmm_objective, Reverse, Const, Const, Const, Duplicated, Duplicated, Duplicated, Const, Const, DuplicatedOnly)] // activities in parameter order: d, k, n, alphas, means, icf, x, wishart, err
+pub unsafe fn gmm_objective(d: usize, k: usize, n: usize, alphas: *const f64, means: *const f64, icf: *const f64, x: *const f64, wishart: *const Wishart, err: *mut f64) {
+ let constant = -(n as f64) * d as f64 * 0.5 * (2.0 * PI).ln();
+ let icf_sz = d * (d + 1) / 2; // stride between per-component blocks of icf
+ let mut qdiags = vec![0.; d * k]; // scratch buffers are allocated inside the function in this variant
+ let mut sum_qs = vec![0.; k];
+ let mut xcentered = vec![0.; d];
+ let mut qxcentered = vec![0.; d];
+ let mut main_term = vec![0.; k];
+
+ preprocess_qs(d, k, icf, sum_qs.as_mut_ptr(), qdiags.as_mut_ptr()); // fills sum_qs and qdiags from icf
+
+ let mut slse = 0.;
+ for ix in 0..n {
+ for ik in 0..k {
+ subtract(d, x.add(ix * d), means.add(ik * d), xcentered.as_mut_ptr());
+ qtimesx(d, qdiags.as_mut_ptr().add(ik * d), icf.add(ik * icf_sz + d), xcentered.as_ptr(), qxcentered.as_mut_ptr()); // ltri starts d entries into the component's icf block
+ main_term[ik] = *alphas.add(ik) + sum_qs[ik] - 0.5 * sqnorm(d, qxcentered.as_ptr());
+ //main_term[ik] = alphas[ik] + sum_qs[ik] - 0.5 * sqnorm(d, &Qxcentered[0]);
+ }
+
+ slse = slse + log_sum_exp(k, main_term.as_ptr());
+ }
+
+ let lse_alphas = log_sum_exp(k, alphas);
+
+ *err = constant + slse - n as f64 * lse_alphas + log_wishart_prior(d, k, *wishart, sum_qs.as_ptr(), qdiags.as_ptr(), icf);
+}
+
+/// Max of the `n` values read from `x` (`NEG_INFINITY` if `n == 0`); `x` must be valid for `n` reads.
+unsafe fn arr_max(n: usize, x: *const f64) -> f64 {
+    let mut best = f64::NEG_INFINITY;
+    for offset in 0..n {
+        let candidate = *x.add(offset);
+        best = if best < candidate { candidate } else { best };
+    }
+    best
+}
+
+unsafe fn preprocess_qs(d: usize, k: usize, icf: *const f64, sum_qs: *mut f64, qdiags: *mut f64) { // writes k entries of sum_qs and k * d entries of qdiags
+ let icf_sz = d * (d + 1) / 2; // stride between per-component blocks of icf
+ for ik in 0..k {
+ *sum_qs.add(ik) = 0.;
+ for id in 0..d {
+ let q = *icf.add(ik * icf_sz + id); // first d entries of component ik's block
+ *sum_qs.add(ik) = *sum_qs.add(ik) + q; // running sum of those d entries
+ *qdiags.add(ik * d + id) = q.exp(); // their exponentials, d per component
+ }
+ }
+}
+
+unsafe fn subtract(d: usize, x: *const f64, y: *const f64, out: *mut f64) { // reads d values from x and y, writes d values to out
+ for i in 0..d {
+ *out.add(i) = *x.add(i) - *y.add(i); // element-wise difference
+ }
+}
+
+unsafe fn qtimesx(d: usize, q_diag: *const f64, ltri: *const f64, x: *const f64, out: *mut f64) { // out = diagonal part plus strictly-lower-triangular part applied to x
+ for i in 0..d {
+ *out.add(i) = *q_diag.add(i) * *x.add(i); // diagonal contribution
+ }
+
+ for i in 0..d {
+ let mut lparamsidx = i*(2*d-i-1)/2; // start of column i in ltri: columns 0..i hold d-1, d-2, ... entries
+ for j in i + 1..d {
+ *out.add(j) = *out.add(j) + *ltri.add(lparamsidx) * *x.add(i); // entry (j, i) below the diagonal
+ lparamsidx += 1;
+ }
+ }
+}
+
+unsafe fn log_sum_exp(n: usize, x: *const f64) -> f64 { // ln(sum(exp(x[i]))) over n values read from x
+ let mx = arr_max(n, x); // subtracted inside exp() and added back after ln()
+ let mut semx: f64 = 0.0;
+
+ for i in 0..n {
+ semx = semx + (*x.add(i) - mx).exp();
+ }
+ semx.ln() + mx
+}
+
+fn log_gamma_distrib(a: f64, p: f64) -> f64 { // 0.25 * p * (p - 1) * ln(PI) plus a sum of lgamma terms
+ 0.25 * p * (p - 1.) * PI.ln() + (1..=p as usize).map(|j| lgamma(a + 0.5 * (1. - j as f64))).sum::<f64>() // p is truncated to an integer count of terms
+}
+
+unsafe fn log_wishart_prior(p: usize, k: usize, wishart: Wishart, sum_qs: *const f64, qdiags: *const f64, icf: *const f64) -> f64 {
+ let n = p + wishart.m as usize + 1; // NOTE(review): a negative m would wrap to a huge usize here
+ let icf_sz = p * (p + 1) / 2; // stride between per-component blocks of icf
+
+ let c = n as f64 * p as f64 * (wishart.gamma.ln() - 0.5 * 2f64.ln()) - log_gamma_distrib(0.5 * n as f64, p as f64);
+
+ let mut out = 0.;
+
+ for ik in 0..k {
+ let frobenius = sqnorm(p, qdiags.add(ik * p)) + sqnorm(icf_sz - p, icf.add(ik * icf_sz + p)); // p qdiags entries plus the icf_sz - p icf entries after the first p
+ out = out + 0.5 * wishart.gamma * wishart.gamma * (frobenius) - wishart.m as f64 * *sum_qs.add(ik);
+ }
+
+ out - k as f64 * c // c is subtracted once per component
+}
+
+unsafe fn sqnorm(n: usize, x: *const f64) -> f64 { // sum of squares of the n values read from x
+ let mut sum = 0.;
+ for i in 0..n {
+ sum += *x.add(i) * *x.add(i);
+ }
+ sum
+}
diff --git a/enzyme/benchmarks/ReverseMode/lstm/Cargo.lock b/enzyme/benchmarks/ReverseMode/lstm/Cargo.lock
new file mode 100644
index 0000000..270bf43
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/lstm/Cargo.lock
@@ -0,0 +1,7 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "lstm"
+version = "0.1.0"
diff --git a/enzyme/benchmarks/ReverseMode/lstm/Cargo.toml b/enzyme/benchmarks/ReverseMode/lstm/Cargo.toml
new file mode 100644
index 0000000..d28f845
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/lstm/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "lstm"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+
+[lib]
+crate-type = ["lib"]
+
+[profile.release]
+lto = "fat"
+opt-level = 3
+codegen-units = 1
+panic = "abort" # Cargo's profile key is `panic`; `unwind` is not a recognized key and was ignored
+strip = true
+#overflow-checks = false
+
+[profile.dev]
+lto = "fat"
diff --git a/enzyme/benchmarks/ReverseMode/lstm/Makefile.make b/enzyme/benchmarks/ReverseMode/lstm/Makefile.make
index 276c5df..1388a54 100644
--- a/enzyme/benchmarks/ReverseMode/lstm/Makefile.make
+++ b/enzyme/benchmarks/ReverseMode/lstm/Makefile.make
@@ -6,6 +6,10 @@
clean:
rm -f *.ll *.o results.txt results.json
+ cargo +enzyme clean
+
+$(dir)/benchmarks/ReverseMode/lstm/target/release/liblstm.a: src/lib.rs Cargo.toml
+ RUSTFLAGS="-Z autodiff=Enable,LooseTypes" cargo +enzyme rustc --release --lib --crate-type=staticlib
%-unopt.ll: %.cpp
clang++ $(BENCH) $(PTR) $^ -pthread -O2 -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm
@@ -16,8 +20,8 @@
%-opt.ll: %-raw.ll
opt $^ -o $@ -S
-lstm.o: lstm-opt.ll
+lstm.o: lstm-opt.ll $(dir)/benchmarks/ReverseMode/lstm/target/release/liblstm.a
clang++ -pthread -O2 $^ -o $@ $(BENCHLINK) -lm
results.json: lstm.o
- ./$^
+ numactl -C 1 ./$^
diff --git a/enzyme/benchmarks/ReverseMode/lstm/lstm.cpp b/enzyme/benchmarks/ReverseMode/lstm/lstm.cpp
index dbbc992..ade0b22 100644
--- a/enzyme/benchmarks/ReverseMode/lstm/lstm.cpp
+++ b/enzyme/benchmarks/ReverseMode/lstm/lstm.cpp
@@ -50,15 +50,10 @@
// LSTM OBJECTIVE
// The LSTM model
-void lstm_model(
- int hsize,
- double const* __restrict weight,
- double const* __restrict bias,
- double* __restrict hidden,
- double* __restrict cell,
- double const* __restrict input
-)
-{
+void lstm_model_restrict(int hsize, double const *__restrict weight,
+ double const *__restrict bias,
+ double *__restrict hidden, double *__restrict cell,
+ double const *__restrict input) {
// TODO NOTE THIS
//__builtin_assume(hsize > 0);
@@ -94,16 +89,9 @@
}
// Predict LSTM output given an input
-void lstm_predict(
- int l,
- int b,
- double const* __restrict w,
- double const* __restrict w2,
- double* __restrict s,
- double const* __restrict x,
- double* __restrict x2
-)
-{
+void lstm_predict_restrict(int l, int b, double const *__restrict w,
+ double const *__restrict w2, double *__restrict s,
+ double const *__restrict x, double *__restrict x2) {
int i;
for (i = 0; i < b; i++)
{
@@ -113,7 +101,8 @@
double* xp = x2;
for (i = 0; i <= 2 * l * b - 1; i += 2 * b)
{
- lstm_model(b, &(w[i * 4]), &(w[(i + b) * 4]), &(s[i]), &(s[i + b]), xp);
+ lstm_model_restrict(b, &(w[i * 4]), &(w[(i + b) * 4]), &(s[i]),
+ &(s[i + b]), xp);
xp = &(s[i]);
}
@@ -124,17 +113,12 @@
}
// LSTM objective (loss function)
-void lstm_objective(
- int l,
- int c,
- int b,
- double const* __restrict main_params,
- double const* __restrict extra_params,
- double* __restrict state,
- double const* __restrict sequence,
- double* __restrict loss
-)
-{
+void cxx_restrict_lstm_objective(int l, int c, int b,
+ double const *__restrict main_params,
+ double const *__restrict extra_params,
+ double *__restrict state,
+ double const *__restrict sequence,
+ double *__restrict loss) {
int i, t;
double total = 0.0;
int count = 0;
@@ -147,7 +131,8 @@
__builtin_assume(b>0);
for (t = 0; t <= (c - 1) * b - 1; t += b)
{
- lstm_predict(l, b, main_params, extra_params, state, input, ypred);
+ lstm_predict_restrict(l, b, main_params, extra_params, state, input,
+ ypred);
lse = logsumexp(ypred, b);
for (i = 0; i < b; i++)
{
@@ -177,32 +162,17 @@
// * tapenade -b -o lstm_tapenade -head "lstm_objective(loss)/(main_params extra_params)" lstm.c
-void dlstm_objective(
- int l,
- int c,
- int b,
- double const* main_params,
- double* dmain_params,
- double const* extra_params,
- double* dextra_params,
- double* state,
- double const* sequence,
- double* loss,
- double* dloss
-)
-{
- __enzyme_autodiff(lstm_objective,
- enzyme_const, l,
- enzyme_const, c,
- enzyme_const, b,
- enzyme_dup, main_params, dmain_params,
- enzyme_dup, extra_params, dextra_params,
- enzyme_const, state,
- enzyme_const, sequence,
- enzyme_dupnoneed, loss, dloss
- );
+void dlstm_objective_restrict(int l, int c, int b, double const *main_params,
+ double *dmain_params, double const *extra_params,
+ double *dextra_params, double *state,
+ double const *sequence, double *loss,
+ double *dloss) {
+ __enzyme_autodiff(cxx_restrict_lstm_objective, enzyme_const, l,
+ enzyme_const, c, enzyme_const, b, enzyme_dup, main_params,
+ dmain_params, enzyme_dup, extra_params, dextra_params,
+ enzyme_const, state, enzyme_const, sequence,
+ enzyme_dupnoneed, loss, dloss);
}
-
}
@@ -728,3 +698,5 @@
}
#endif
+
+#include "lstm_mayalias.h"
diff --git a/enzyme/benchmarks/ReverseMode/lstm/lstm_mayalias.h b/enzyme/benchmarks/ReverseMode/lstm/lstm_mayalias.h
new file mode 100644
index 0000000..06401ff
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/lstm/lstm_mayalias.h
@@ -0,0 +1,160 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+/*
+ * File "lstm_b_tapenade_generated.c" is generated by Tapenade 3.14 (r7259) from this file.
+ * To reproduce such a generation you can use Tapenade CLI
+ * (can be downloaded from http://www-sop.inria.fr/tropics/tapenade/downloading.html)
+ *
+ * After installing use the next command to generate a file:
+ *
+ * tapenade -b -o lstm_tapenade -head "lstm_objective(loss)/(main_params extra_params)" lstm.c
+ *
+ * This will produce a file "lstm_tapenade_b.c" whose content will be the same as the content of the file "lstm_b_tapenade_generated.c",
+ * except one-line header. Moreover a log-file "lstm_tapenade_b.msg" will be produced.
+ *
+ * NOTE: the code in "lstm_b_tapenade_generated.c" is wrong and won't work.
+ * REPAIRED SOURCE IS STORED IN THE FILE "lstm_b.c".
+ * You can either use a diff tool or read the "lstm_b.c" header to figure out what changes were performed to fix the code.
+ *
+ * NOTE: you can also use Tapenade web server (http://tapenade.inria.fr:8080/tapenade/index.jsp)
+ * for generating but the result can be slightly different.
+ */
+
+// #include "../adbench/lstm.h"
+
+extern "C" {
+// #include "lstm.h"
+
+// UTILS
+// Sigmoid on scalar
+// double sigmoid(double x)
+//{
+// return 1.0 / (1.0 + exp(-x));
+//}
+//
+//// log(sum(exp(x), 2))
+// double logsumexp(double const* vect, int sz)
+//{
+// double sum = 0.0;
+// int i;
+//
+// for (i = 0; i < sz; i++)
+// {
+// sum += exp(vect[i]);
+// }
+//
+// sum += 2;
+// return log(sum);
+// }
+
+// LSTM OBJECTIVE
+// The LSTM model
+void lstm_model(int hsize, double const *weight, double const *bias,
+ double *hidden, double *cell, double const *input) { // updates cell and hidden in place
+ // TODO NOTE THIS
+ //__builtin_assume(hsize > 0);
+
+ double *gates = (double *)malloc(4 * hsize * sizeof(double)); // NOTE(review): malloc result is not checked
+ double *forget = &(gates[0]); // the four gate vectors are consecutive hsize-long slices of one buffer
+ double *ingate = &(gates[hsize]);
+ double *outgate = &(gates[2 * hsize]);
+ double *change = &(gates[3 * hsize]);
+
+ int i;
+ // caching input
+ // hidden (needed)
+ for (i = 0; i < hsize; i++) {
+ forget[i] = sigmoid(input[i] * weight[i] + bias[i]); // forget and outgate read input
+ ingate[i] = sigmoid(hidden[i] * weight[hsize + i] + bias[hsize + i]); // ingate and change read hidden
+ outgate[i] =
+ sigmoid(input[i] * weight[2 * hsize + i] + bias[2 * hsize + i]);
+ change[i] = tanh(hidden[i] * weight[3 * hsize + i] + bias[3 * hsize + i]);
+ }
+
+ // caching cell (needed)
+ for (i = 0; i < hsize; i++) {
+ cell[i] = cell[i] * forget[i] + ingate[i] * change[i];
+ }
+
+ for (i = 0; i < hsize; i++) {
+ hidden[i] = outgate[i] * tanh(cell[i]); // uses the cell value updated just above
+ }
+
+ free(gates);
+}
+
+// Predict LSTM output given an input: scales x by w2, runs the stacked lstm_model calls over s, writes the result to x2
+void lstm_predict(int l, int b, double const *w, double const *w2, double *s,
+ double const *x, double *x2) {
+ int i;
+ for (i = 0; i < b; i++) {
+ x2[i] = x[i] * w2[i]; // first b entries of w2 scale the input
+ }
+
+ double *xp = x2;
+ for (i = 0; i <= 2 * l * b - 1; i += 2 * b) { // one iteration per layer; each layer owns 2 * b entries of s
+ lstm_model(b, &(w[i * 4]), &(w[(i + b) * 4]), &(s[i]), &(s[i + b]), xp); // hidden = s[i..], cell = s[i + b..]
+ xp = &(s[i]); // the hidden slice just written feeds the next layer
+ }
+
+ for (i = 0; i < b; i++) {
+ x2[i] = xp[i] * w2[b + i] + w2[2 * b + i]; // second and third b-long parts of w2 scale and shift the output
+ }
+}
+
+// LSTM objective (loss function): writes -total / count to *loss
+void cxx_mayalias_lstm_objective(int l, int c, int b, double const *main_params,
+ double const *extra_params, double *state,
+ double const *sequence, double *loss) {
+ int i, t;
+ double total = 0.0;
+ int count = 0;
+ const double *input = &(sequence[0]);
+ double *ypred = (double *)malloc(b * sizeof(double)); // NOTE(review): malloc results are not checked
+ double *ynorm = (double *)malloc(b * sizeof(double));
+ const double *ygold;
+ double lse;
+
+ __builtin_assume(b > 0);
+ for (t = 0; t <= (c - 1) * b - 1; t += b) { // t advances over sequence in steps of b
+ lstm_predict(l, b, main_params, extra_params, state, input, ypred);
+ lse = logsumexp(ypred, b);
+ for (i = 0; i < b; i++) {
+ ynorm[i] = ypred[i] - lse;
+ }
+
+ ygold = &(sequence[t + b]); // the b entries following the current input
+ for (i = 0; i < b; i++) {
+ total += ygold[i] * ynorm[i];
+ }
+
+ count += b;
+ input = ygold; // the target of this step becomes the next input
+ }
+
+ *loss = -total / count; // NOTE(review): count stays 0 when the loop body never runs
+
+ free(ypred);
+ free(ynorm);
+}
+
+extern int enzyme_const;
+extern int enzyme_dup;
+extern int enzyme_dupnoneed;
+void __enzyme_autodiff(...) noexcept;
+
+// * tapenade -b -o lstm_tapenade -head "lstm_objective(loss)/(main_params extra_params)" lstm.c
+
+void dlstm_objective_mayalias(int l, int c, int b, double const *main_params,
+ double *dmain_params, double const *extra_params,
+ double *dextra_params, double *state,
+ double const *sequence, double *loss,
+ double *dloss) { // differentiates cxx_mayalias_lstm_objective through __enzyme_autodiff
+ __enzyme_autodiff(cxx_mayalias_lstm_objective, enzyme_const, l, enzyme_const,
+ c, enzyme_const, b, enzyme_dup, main_params, dmain_params, // each enzyme_dup argument is followed by its shadow buffer
+ enzyme_dup, extra_params, dextra_params, enzyme_const,
+ state, enzyme_const, sequence, enzyme_dupnoneed, loss,
+ dloss); // dloss is the shadow paired with loss
+}
+}
diff --git a/enzyme/benchmarks/ReverseMode/lstm/src/lib.rs b/enzyme/benchmarks/ReverseMode/lstm/src/lib.rs
new file mode 100644
index 0000000..937460f
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/lstm/src/lib.rs
@@ -0,0 +1,56 @@
+#![feature(autodiff)]
+
+pub (crate) mod unsf;
+pub (crate) mod safe;
+use std::slice;
+
+
+#[no_mangle]
+pub extern "C" fn rust_unsafe_lstm_objective(l: i32, c: i32, b: i32, main_params: *const f64, extra_params: *const f64, state: *mut f64, sequence: *const f64, loss: *mut f64) {
+ let l = l as usize; // NOTE(review): a negative i32 would wrap to a huge usize in these casts
+ let c = c as usize;
+ let b = b as usize;
+ unsafe {unsf::lstm_unsafe_objective(l,c,b,main_params,extra_params,state,sequence, loss);} // raw pointers are forwarded unchanged
+}
+#[no_mangle]
+pub extern "C" fn rust_safe_lstm_objective(l: i32, c: i32, b: i32, main_params: *const f64, extra_params: *const f64, state: *mut f64, sequence: *const f64, loss: *mut f64) {
+ let l = l as usize; // NOTE(review): a negative i32 would wrap to a huge usize in these casts
+ let c = c as usize;
+ let b = b as usize;
+ let (main_params, extra_params, state, sequence) = unsafe {(
+ slice::from_raw_parts(main_params, 2*l*4*b), // the caller must provide at least these many elements behind each pointer
+ slice::from_raw_parts(extra_params, 3*b),
+ slice::from_raw_parts_mut(state, 2*l*b),
+ slice::from_raw_parts(sequence, c*b)
+ )};
+
+ unsafe {
+ safe::lstm_objective(l,c,b,main_params,extra_params,state,sequence, &mut *loss); // the unsafe block covers the dereference of the raw loss pointer
+ }
+}
+
+#[no_mangle]
+pub extern "C" fn rust_unsafe_dlstm_objective(l: i32, c: i32, b: i32, main_params: *const f64, d_main_params: *mut f64, extra_params: *const f64, d_extra_params: *mut f64, state: *mut f64, sequence: *const f64, res: *mut f64, d_res: *mut f64) {
+ let l = l as usize; // NOTE(review): a negative i32 would wrap to a huge usize in these casts
+ let c = c as usize;
+ let b = b as usize;
+ unsafe {unsf::d_lstm_unsafe_objective(l,c,b,main_params,d_main_params, extra_params,d_extra_params, state,sequence, res, d_res);} // raw pointers are forwarded unchanged
+}
+#[no_mangle]
+pub extern "C" fn rust_safe_dlstm_objective(l: i32, c: i32, b: i32, main_params: *const f64, d_main_params: *mut f64, extra_params: *const f64, d_extra_params: *mut f64, state: *mut f64, sequence: *const f64, res: *mut f64, d_res: *mut f64) {
+ let l = l as usize; // NOTE(review): a negative i32 would wrap to a huge usize in these casts
+ let c = c as usize;
+ let b = b as usize;
+ let (main_params, d_main_params, extra_params, d_extra_params, state, sequence) = unsafe {(
+ slice::from_raw_parts(main_params, 2*l*4*b), // the caller must provide at least these many elements behind each pointer
+ slice::from_raw_parts_mut(d_main_params, 2*l*4*b), // shadow slices use the same lengths as their primals
+ slice::from_raw_parts(extra_params, 3*b),
+ slice::from_raw_parts_mut(d_extra_params, 3*b),
+ slice::from_raw_parts_mut(state, 2*l*b),
+ slice::from_raw_parts(sequence, c*b)
+ )};
+
+ unsafe {
+ safe::d_lstm_objective(l,c,b,main_params,d_main_params, extra_params,d_extra_params, state,sequence, &mut *res, &mut *d_res); // the unsafe block covers the dereferences of res and d_res
+ }
+}
diff --git a/enzyme/benchmarks/ReverseMode/lstm/src/safe.rs b/enzyme/benchmarks/ReverseMode/lstm/src/safe.rs
new file mode 100644
index 0000000..d6847a4
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/lstm/src/safe.rs
@@ -0,0 +1,231 @@
+use std::slice;
+use std::autodiff::autodiff;
+
+// Logistic sigmoid of a scalar: 1 / (1 + e^(-x))
+fn sigmoid(x: f64) -> f64 {
+    (1.0 + f64::exp(-x)).recip()
+}
+
+// log(2 + sum_i exp(vect[i])): log-sum-exp with a constant offset of 2
+#[inline]
+fn logsumexp(vect: &[f64]) -> f64 {
+    // Accumulate left to right starting from 0.0, in the same order as a plain loop.
+    let exp_total = vect
+        .iter()
+        .fold(0.0, |acc: f64, &val| acc + val.exp());
+    // The constant 2.0 is added once, after the exponentials are summed.
+    (exp_total + 2.0).ln()
+}
+
+// LSTM OBJECTIVE
+// One LSTM cell step: overwrites `cell` and `hidden` in place using `input`.
+fn lstm_model(
+    hsize: usize,
+    weight: &[f64],
+    bias: &[f64],
+    hidden: &mut [f64],
+    cell: &mut [f64],
+    input: &[f64],
+) {
+    // Scratch storage for the four gate vectors, laid out back to back.
+    let mut scratch = vec![0.0; 4 * hsize];
+    let (lower, upper) = scratch.split_at_mut(2 * hsize);
+    let (forget, ingate) = lower.split_at_mut(hsize);
+    let (outgate, change) = upper.split_at_mut(hsize);
+
+    // Start offsets of the second, third and fourth gate sections of `weight` / `bias`.
+    let (in_off, out_off, ch_off) = (hsize, 2 * hsize, 3 * hsize);
+
+    // caching input: forget/out gates read `input`, in/change gates read the old `hidden`
+    for k in 0..hsize {
+        forget[k] = sigmoid(input[k] * weight[k] + bias[k]);
+        ingate[k] = sigmoid(hidden[k] * weight[in_off + k] + bias[in_off + k]);
+        outgate[k] = sigmoid(input[k] * weight[out_off + k] + bias[out_off + k]);
+        change[k] = (hidden[k] * weight[ch_off + k] + bias[ch_off + k]).tanh();
+    }
+
+    // caching cell
+    for k in 0..hsize {
+        cell[k] = cell[k] * forget[k] + ingate[k] * change[k];
+    }
+
+    // new hidden state from the updated cell
+    for k in 0..hsize {
+        hidden[k] = outgate[k] * cell[k].tanh();
+    }
+}
+
+// Predict LSTM output given an input; the prediction is written to `x2`.
+fn lstm_predict(
+    l: usize,
+    b: usize,
+    w: &[f64],
+    w2: &[f64],
+    s: &mut [f64],
+    x: &[f64],
+    x2: &mut [f64],
+) {
+    for k in 0..b {
+        x2[k] = x[k] * w2[k];
+    }
+
+    // `i` is the offset of the current layer's hidden state in `s`; each layer owns 2 * b
+    // entries (hidden, then cell). `i < end` avoids the underflow of `2 * l * b - 1` at zero.
+    let end = 2 * l * b;
+    let mut i = 0;
+    while i < end {
+        // make borrow-checker happy with non-overlapping mutable references
+        let (xp, s1, s2) = if i == 0 {
+            let (s1, s2) = s.split_at_mut(b);
+            (x2.as_mut(), s1, s2)
+        } else {
+            // input = previous layer's state, followed by this layer's hidden and cell
+            let (prev, rest) = s[i - 2 * b..].split_at_mut(2 * b);
+            let (hidden, cell) = rest.split_at_mut(b);
+            (prev, hidden, cell)
+        };
+
+        lstm_model(b, &w[i * 4..(i + b) * 4], &w[(i + b) * 4..(i + 2 * b) * 4], s1, s2, xp);
+        i += 2 * b;
+    }
+
+    if i == 0 {
+        // No layer ran (l == 0 or b == 0): the output layer reads the scaled input in `x2`.
+        for k in 0..b {
+            x2[k] = x2[k] * w2[b + k] + w2[2 * b + k];
+        }
+    } else {
+        let xp = &s[i - 2 * b..];
+        for k in 0..b {
+            x2[k] = xp[k] * w2[b + k] + w2[2 * b + k];
+        }
+    }
+}
+
+// LSTM objective (loss function)
+#[autodiff(
+    d_lstm_objective,
+    Reverse,
+    Const,
+    Const,
+    Const,
+    Duplicated,
+    Duplicated,
+    Const,
+    Const,
+    DuplicatedOnly
+)]
+pub(crate) fn lstm_objective(
+    l: usize,
+    c: usize,
+    b: usize,
+    main_params: &[f64],
+    extra_params: &[f64],
+    state: &mut [f64],
+    sequence: &[f64],
+    loss: &mut f64,
+) {
+    // Average negative log-likelihood over the c - 1 next-step predictions; each step
+    // also advances `state` through `lstm_predict`.
+    let mut total = 0.0;
+    let mut input = &sequence[..b];
+    let mut ypred = vec![0.0; b];
+    let mut ynorm = vec![0.0; b];
+
+    //debug_assert!(b > 0);
+    let steps = c - 1;
+    for j in 0..steps {
+        lstm_predict(l, b, main_params, extra_params, state, input, &mut ypred);
+        let lse = logsumexp(&ypred);
+        for i in 0..b {
+            ynorm[i] = ypred[i] - lse;
+        }
+
+        // The gold output of step j is the next row of `sequence`.
+        let ygold = &sequence[(j + 1) * b..];
+        for i in 0..b {
+            total += ygold[i] * ynorm[i];
+        }
+
+        input = ygold;
+    }
+
+    // NaN when c == 1 (0.0 / 0.0): no prediction steps were taken.
+    *loss = -total / (steps * b) as f64;
+}
+
+#[no_mangle]
+pub extern "C" fn rust_lstm_objective(
+    l: usize,
+    c: usize,
+    b: usize,
+    main_params: *const f64,
+    extra_params: *const f64,
+    state: *mut f64,
+    sequence: *const f64,
+    loss: *mut f64,
+) {
+    // C ABI entry point defined next to the objective: wraps the raw buffers in slices
+    // and forwards to `lstm_objective`, which writes its result through `loss`.
+    // Element counts of the four buffers:
+    let n_main = 2 * l * 4 * b;
+    let n_extra = 3 * b;
+    let n_state = 2 * l * b;
+    let n_seq = c * b;
+
+    // SAFETY: not checked here; relies on the caller passing buffers of at least these lengths.
+    unsafe {
+        lstm_objective(
+            l,
+            c,
+            b,
+            slice::from_raw_parts(main_params, n_main),
+            slice::from_raw_parts(extra_params, n_extra),
+            slice::from_raw_parts_mut(state, n_state),
+            slice::from_raw_parts(sequence, n_seq),
+            &mut *loss,
+        );
+    }
+}
+
+#[no_mangle]
+pub extern "C" fn rust_dlstm_objective(
+    l: usize,
+    c: usize,
+    b: usize,
+    main_params: *const f64,
+    d_main_params: *mut f64,
+    extra_params: *const f64,
+    d_extra_params: *mut f64,
+    state: *mut f64,
+    sequence: *const f64,
+    res: *mut f64,
+    d_res: *mut f64,
+) {
+    // C ABI entry point for the generated reverse-mode gradient `d_lstm_objective`.
+    // `d_main_params` / `d_extra_params` are the derivative buffers that accompany the
+    // two `Duplicated` parameter slices of the autodiff declaration.
+
+    // Element counts of the buffers:
+    let n_main = 2 * l * 4 * b;
+    let n_extra = 3 * b;
+    let n_state = 2 * l * b;
+    let n_seq = c * b;
+
+    // SAFETY: not checked here; relies on the caller passing buffers of at least these lengths.
+    unsafe {
+        d_lstm_objective(
+            l,
+            c,
+            b,
+            slice::from_raw_parts(main_params, n_main),
+            slice::from_raw_parts_mut(d_main_params, n_main),
+            slice::from_raw_parts(extra_params, n_extra),
+            slice::from_raw_parts_mut(d_extra_params, n_extra),
+            slice::from_raw_parts_mut(state, n_state),
+            slice::from_raw_parts(sequence, n_seq),
+            &mut *res,
+            &mut *d_res,
+        );
+    }
+}
diff --git a/enzyme/benchmarks/ReverseMode/lstm/src/unsf.rs b/enzyme/benchmarks/ReverseMode/lstm/src/unsf.rs
new file mode 100644
index 0000000..498bf96
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/lstm/src/unsf.rs
@@ -0,0 +1,116 @@
+use std::autodiff::autodiff;
+
+// Sigmoid on scalar: 1 / (1 + exp(-x))
+fn sigmoid(x: f64) -> f64 {
+    f64::recip(1.0 + (-x).exp())
+}
+
+// log(2 + sum of exp(vect[i]) for i in 0..sz)
+unsafe fn logsumexp(vect: *const f64, sz: usize) -> f64 {
+    // Left-to-right accumulation from 0.0, exactly like a plain indexed loop.
+    let exp_total = (0..sz).fold(0.0_f64, |acc, k| {
+        // The caller must guarantee that `vect` is readable for `sz` elements.
+        acc + (*vect.add(k)).exp()
+    });
+    (exp_total + 2.0).ln()
+}
+
+// LSTM OBJECTIVE
+// The LSTM model: one cell step that updates `hidden` and `cell` in place (raw-pointer version)
+unsafe fn lstm_model(
+    hsize: usize,
+    weight: *const f64,
+    bias: *const f64,
+    hidden: *mut f64,
+    cell: *mut f64,
+    input: *const f64,
+) {
+// // TODO NOTE THIS
+// //__builtin_assume(hsize > 0);
+    let mut gates = vec![0.0; 4 * hsize];
+    // Derive all four gate pointers from a single base pointer. Re-indexing `gates`
+    // (`gates[hsize..]`) for each one reborrows the whole buffer mutably, which
+    // invalidates the raw pointers taken before it under Rust's aliasing rules.
+    let forget: *mut f64 = gates.as_mut_ptr();
+    let ingate: *mut f64 = forget.add(hsize);
+    let outgate: *mut f64 = forget.add(2 * hsize);
+    let change: *mut f64 = forget.add(3 * hsize);
+
+    // caching input
+    for i in 0..hsize {
+        *forget.add(i) = sigmoid(*input.add(i) * *weight.add(i) + *bias.add(i));
+        *ingate.add(i) = sigmoid(*hidden.add(i) * *weight.add(hsize + i) + *bias.add(hsize + i));
+        *outgate.add(i) = sigmoid(*input.add(i) * *weight.add(2 * hsize + i) + *bias.add(2 * hsize + i));
+        *change.add(i) = (*hidden.add(i) * *weight.add(3 * hsize + i) + *bias.add(3 * hsize + i)).tanh();
+    }
+
+    // caching cell
+    for i in 0..hsize {
+        *cell.add(i) = *cell.add(i) * *forget.add(i) + *ingate.add(i) * *change.add(i);
+    }
+
+    for i in 0..hsize {
+        *hidden.add(i) = *outgate.add(i) * (*cell.add(i)).tanh();
+    }
+}
+
+// Predict LSTM output given an input; the prediction is written to `x2`
+unsafe fn lstm_predict(
+    l: usize,
+    b: usize,
+    w: *const f64,
+    w2: *const f64,
+    s: *mut f64,
+    x: *const f64,
+    x2: *mut f64,
+) {
+    for k in 0..b {
+        *x2.add(k) = *x.add(k) * *w2.add(k);
+    }
+
+    let mut xp = x2;
+    // Half-open range: the former `0..=stop - 1` underflowed when l == 0 (stop == 0).
+    for i in (0..2 * l * b).step_by(2 * b) {
+        lstm_model(b, w.add(i * 4), w.add((i + b) * 4), s.add(i), s.add(i + b), xp);
+        xp = s.add(i);
+    }
+
+    for k in 0..b {
+        *x2.add(k) = *xp.add(k) * *w2.add(b + k) + *w2.add(2 * b + k);
+    }
+}
+
+// LSTM objective (loss function): average negative log-likelihood of the next-row predictions
+#[autodiff(d_lstm_unsafe_objective, Reverse, Const, Const, Const, Duplicated, Duplicated, Const, Const, DuplicatedOnly)]
+pub (crate) unsafe fn lstm_unsafe_objective(l: usize, c: usize, b: usize, main_params: *const f64, extra_params: *const f64, state: *mut f64, sequence: *const f64, loss: *mut f64) {
+    let mut total = 0.0;
+    let mut count = 0;
+    // The first row of `sequence` is the initial input.
+    let mut input = sequence;
+    let mut ypred = vec![0.0; b];
+    let mut ynorm = vec![0.0; b];
+
+    assert!(b > 0);
+
+    // Half-open range: `0..=stop - 1` underflows when c == 1 (stop == 0), which panics with
+    // overflow checks and otherwise walks far past the end of `sequence`.
+    let stop = (c - 1) * b;
+    for t in (0..stop).step_by(b) {
+        lstm_predict(l, b, main_params, extra_params, state, input, ypred.as_mut_ptr());
+        let lse = logsumexp(ypred.as_ptr(), b);
+        for i in 0..b {
+            ynorm[i] = ypred[i] - lse;
+        }
+
+        // The gold output for this step is the row after `t`.
+        let ygold = sequence.add(t + b);
+        for i in 0..b {
+            total += *ygold.add(i) * ynorm[i];
+        }
+
+        count += b;
+        input = ygold;
+    }
+
+    *loss = -total / count as f64;
+}
diff --git a/enzyme/benchmarks/ReverseMode/ode-real/Cargo.lock b/enzyme/benchmarks/ReverseMode/ode-real/Cargo.lock
new file mode 100644
index 0000000..93dcf6a
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ode-real/Cargo.lock
@@ -0,0 +1,7 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "ode"
+version = "0.1.0"
diff --git a/enzyme/benchmarks/ReverseMode/ode-real/Cargo.toml b/enzyme/benchmarks/ReverseMode/ode-real/Cargo.toml
new file mode 100644
index 0000000..b7386a4
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ode-real/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "ode"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+
+[lib]
+crate-type = ["lib"]
+
+[profile.release]
+lto = "fat"
+opt-level = 3
+codegen-units = 1
+panic = "abort" # was `unwind = "abort"`, which is not a Cargo profile key and was ignored
+strip = true
+#overflow-checks = false
+
+[profile.dev]
+lto = "fat"
diff --git a/enzyme/benchmarks/ReverseMode/ode-real/Makefile.make b/enzyme/benchmarks/ReverseMode/ode-real/Makefile.make
index 5abb283..87af95f 100644
--- a/enzyme/benchmarks/ReverseMode/ode-real/Makefile.make
+++ b/enzyme/benchmarks/ReverseMode/ode-real/Makefile.make
@@ -6,6 +6,10 @@
clean:
rm -f *.ll *.o results.txt results.json
+ cargo +enzyme clean
+
+$(dir)/benchmarks/ReverseMode/ode-real/target/release/libode.a: src/lib.rs Cargo.toml
+ RUSTFLAGS="-Z autodiff=Enable,LooseTypes" cargo +enzyme rustc --release --lib --crate-type=staticlib
%-unopt.ll: %.cpp
clang++ $(BENCH) $(PTR) $^ -O2 -fno-use-cxa-atexit -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm
@@ -16,17 +20,17 @@
%-opt.ll: %-raw.ll
opt $^ -o $@ -S
-ode.o: ode-opt.ll
+ode.o: ode-opt.ll $(dir)/benchmarks/ReverseMode/ode-real/target/release/libode.a
clang++ $(BENCH) -O2 $^ -o $@ $(BENCHLINK)
results.json: ode.o
- ./$^ 1000 | tee $@
- ./$^ 1000 >> $@
- ./$^ 1000 >> $@
- ./$^ 1000 >> $@
- ./$^ 1000 >> $@
- ./$^ 1000 >> $@
- ./$^ 1000 >> $@
- ./$^ 1000 >> $@
- ./$^ 1000 >> $@
- ./$^ 1000 >> $@
+ numactl -C 1 ./$^ 1000 | tee $@
+ numactl -C 1 ./$^ 1000 >> $@
+ numactl -C 1 ./$^ 1000 >> $@
+ numactl -C 1 ./$^ 1000 >> $@
+ numactl -C 1 ./$^ 1000 >> $@
+ numactl -C 1 ./$^ 1000 >> $@
+ numactl -C 1 ./$^ 1000 >> $@
+ numactl -C 1 ./$^ 1000 >> $@
+ numactl -C 1 ./$^ 1000 >> $@
+ numactl -C 1 ./$^ 1000 >> $@
diff --git a/enzyme/benchmarks/ReverseMode/ode-real/ode.cpp b/enzyme/benchmarks/ReverseMode/ode-real/ode.cpp
index 7c7113d..17007c8 100644
--- a/enzyme/benchmarks/ReverseMode/ode-real/ode.cpp
+++ b/enzyme/benchmarks/ReverseMode/ode-real/ode.cpp
@@ -24,20 +24,8 @@
return (end->tv_sec-start->tv_sec) + 1e-6*(end->tv_usec-start->tv_usec);
}
-#define BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS
-#define BOOST_NO_EXCEPTIONS
#include <iostream>
-#include <boost/array.hpp>
-
-#include <boost/numeric/odeint.hpp>
-
-#include <boost/throw_exception.hpp>
-void boost::throw_exception(std::exception const & e){
- //do nothing
-}
-
using namespace std;
-using namespace boost::numeric::odeint;
#define N 32
#define xmin 0.
@@ -76,7 +64,7 @@
}
__attribute__((noinline))
-void brusselator_2d_loop(double* __restrict du, double* __restrict dv, const double* __restrict u, const double* __restrict v, const double* __restrict p, double t) {
+void brusselator_2d_loop_restrict(double* __restrict du, double* __restrict dv, const double* __restrict u, const double* __restrict v, const double* __restrict p, double t) {
double A = p[0];
double B = p[1];
double alpha = p[2];
@@ -107,33 +95,131 @@
}
}
-typedef boost::array< double , 2 * N * N > state_type;
+__attribute__((noinline))
+void brusselator_2d_loop_norestrict(double* du, double* dv, const double* u, const double* v, const double* p, double t) { // variant whose pointers carry no __restrict qualifier
+ double A = p[0];
+ double B = p[1];
+ double alpha = p[2];
+ double dx = (double)1/(N-1);
-void lorenz( const state_type &x , state_type &dxdt , double t )
+ alpha = alpha/(dx*dx); // fold the 1/dx^2 factor into alpha
+
+ for(int i=0; i<N; i++) {
+ for(int j=0; j<N; j++) {
+
+ double x = RANGE(xmin, xmax, i, N);
+ double y = RANGE(ymin, ymax, j, N);
+
+ unsigned ip1 = (i == N-1) ? i : (i+1); // neighbour indices stay on the edge cell at the boundary
+ unsigned im1 = (i == 0) ? i : (i-1);
+
+ unsigned jp1 = (j == N-1) ? j : (j+1);
+ unsigned jm1 = (j == 0) ? j : (j-1);
+
+ double u2v = GET(u, i, j) * GET(u, i, j) * GET(v, i, j);
+
+ GETnb(du, i, j) = alpha*( GET(u, im1, j) + GET(u, ip1, j) + GET(u, i, jp1) + GET(u, i, jm1) - 4 * GET(u, i, j))
+ + B + u2v - (A + 1)*GET(u, i, j) + brusselator_f(x, y, t);
+
+ GETnb(dv, i, j) = alpha*( GET(v, im1, j) + GET(v, ip1, j) + GET(v, i, jp1) + GET(v, i, jm1) - 4 * GET(v, i, j))
+ + A * GET(u, i, j) - u2v;
+ }
+ }
+}
+
+typedef double state_type[2*N*N];
+
+void lorenz_norestrict( const state_type &x, state_type &dxdt, double t )
{
// Extract the parameters
double p[3] = { /*A*/ 3.4, /*B*/ 1, /*alpha*/10. };
- brusselator_2d_loop(dxdt.c_array(), dxdt.c_array() + N * N, x.data(), x.data() + N * N, p, t);
+ brusselator_2d_loop_norestrict(dxdt, dxdt + N * N, x, x + N * N, p, t); // first N * N entries go to du / u, the rest to dv / v
}
-// init_brusselator(x.c_array(), x.c_array() + N*N)
+void lorenz_restrict( const state_type &x, state_type &dxdt, double t )
+{
+ // Fixed model parameters (nothing is read from x here); first N * N entries of x / dxdt go to u / du, the rest to v / dv
+ double p[3] = { /*A*/ 3.4, /*B*/ 1, /*alpha*/10. };
+ brusselator_2d_loop_restrict(dxdt, dxdt + N * N, x, x + N * N, p, t);
+}
-double foobar(const double* p, const state_type x, const state_type adjoint, double t) {
+extern "C" void rust_lorenz_safe(const double* x, double* dxdt, double t);
+extern "C" void rust_dbrusselator_2d_loop_safe(double* adjoint, const double* x, double* dx, const double* p, double* dp, double t);
+extern "C" void rust_lorenz_unsf(const double* x, double* dxdt, double t);
+extern "C" void rust_dbrusselator_2d_loop_unsf(double* adjoint, const double* x, double* dx, const double* p, double* dp, double t);
+
+double rustfoobar_unsf(const double *p, const state_type x, const state_type adjoint, double t) {
+  double dp[3] = { 0. };
+  state_type dx = { 0. };
+
+  // Seed with the full adjoint: state_type holds 2 * N * N doubles (u half, then v half),
+  // so copying only N * N would leave the v half of the seed uninitialized.
+  state_type dadjoint_inp;
+  for (int i = 0; i < 2 * N * N; i++) {
+    dadjoint_inp[i] = adjoint[i];
+  }
+  rust_dbrusselator_2d_loop_unsf(dadjoint_inp, x, dx, p, dp, t);
+  return dx[0];
+}
+
+double rustfoobar_safe(const double *p, const state_type x, const state_type adjoint, double t) {
+  double dp[3] = { 0. };
+  state_type dx = { 0. };
+
+  // Seed with the full adjoint: state_type holds 2 * N * N doubles (u half, then v half),
+  // so copying only N * N would leave the v half of the seed uninitialized.
+  state_type dadjoint_inp;
+  for (int i = 0; i < 2 * N * N; i++) {
+    dadjoint_inp[i] = adjoint[i];
+  }
+  rust_dbrusselator_2d_loop_safe(dadjoint_inp, x, dx, p, dp, t);
+  return dx[0];
+}
+
+double foobar_restrict(const double* p, const state_type x, const state_type adjoint, double t) {
double dp[3] = { 0. };
state_type dx = { 0. };
- state_type dadjoint_inp = adjoint;
+ state_type dadjoint_inp;// = adjoint
+ for (int i = 0; i < 2 * N * N; i++) { // copy both halves: state_type holds 2 * N * N doubles and the v half is read below
+ dadjoint_inp[i] = adjoint[i];
+ }
state_type dxdu;
- __enzyme_autodiff<void>(brusselator_2d_loop,
-// enzyme_dup, dxdu.c_array(), dadjoint_inp.c_array(),
-// enzyme_dup, dxdu.c_array() + N * N, dadjoint_inp.c_array() + N * N,
- enzyme_dupnoneed, nullptr, dadjoint_inp.data(),
- enzyme_dupnoneed, nullptr, dadjoint_inp.data() + N * N,
- enzyme_dup, x.data(), dx.data(),
- enzyme_dup, x.data() + N * N, dx.data() + N * N,
+ __enzyme_autodiff<void>(brusselator_2d_loop_restrict,
+ enzyme_dup, dxdu, dadjoint_inp,
+ enzyme_dup, dxdu + N * N, dadjoint_inp + N * N,
+ // enzyme_dupnoneed, nullptr, dadjoint_inp,
+ // enzyme_dupnoneed, nullptr, dadjoint_inp + N * N,
+ enzyme_dup, x, dx,
+ enzyme_dup, x + N * N, dx + N * N,
+ enzyme_dup, p, dp,
+ enzyme_const, t);
+
+ return dx[0];
+}
+
+double foobar_norestrict(const double* p, const state_type x, const state_type adjoint, double t) {
+ double dp[3] = { 0. };
+
+ state_type dx = { 0. };
+
+ state_type dadjoint_inp;// = adjoint
+ for (int i = 0; i < N * N; i++) {
+ dadjoint_inp[i] = adjoint[i];
+ }
+
+ state_type dxdu;
+
+ __enzyme_autodiff<void>(brusselator_2d_loop_norestrict,
+ enzyme_dup, dxdu, dadjoint_inp,
+ enzyme_dup, dxdu + N * N, dadjoint_inp + N * N,
+ // enzyme_dupnoneed, nullptr, dadjoint_inp,
+ // enzyme_dupnoneed, nullptr, dadjoint_inp + N * N,
+ enzyme_dup, x, dx,
+ enzyme_dup, x + N * N, dx + N * N,
enzyme_dup, p, dp,
enzyme_const, t);
@@ -486,14 +572,17 @@
state_type dx = { 0. };
- state_type dadjoint_inp = adjoint;
+ state_type dadjoint_inp;// = adjoint
+ for (int i = 0; i < N * N; i++) {
+ dadjoint_inp[i] = adjoint[i];
+ }
state_type dxdu;
- brusselator_2d_loop_b(nullptr, dadjoint_inp.data(),
- nullptr, dadjoint_inp.data() + N * N,
- x.data(), dx.data(),
- x.data() + N * N, dx.data() + N * N,
+ brusselator_2d_loop_b(nullptr, dadjoint_inp,
+ nullptr, dadjoint_inp + N * N,
+ x, dx,
+ x + N * N, dx + N * N,
p, dp,
t);
@@ -505,10 +594,10 @@
const double p[3] = { /*A*/ 3.4, /*B*/ 1, /*alpha*/10. };
state_type x;
- init_brusselator(x.data(), x.data() + N * N);
+ init_brusselator(x, x + N * N);
state_type adjoint;
- init_brusselator(adjoint.data(), adjoint.data() + N * N);
+ init_brusselator(adjoint, adjoint + N * N);
double t = 2.1;
@@ -542,174 +631,97 @@
double res;
for(int i=0; i<10000; i++)
- res = foobar(p, x, adjoint, t);
+ res = foobar_norestrict(p, x, adjoint, t);
gettimeofday(&end, NULL);
- printf("Enzyme combined %0.6f res=%f\n", tdiff(&start, &end), res);
+ printf("C++ Enzyme combined mayalias %0.6f res=%f\n", tdiff(&start, &end), res);
}
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+
+ double res;
+ for(int i=0; i<10000; i++)
+ res = foobar_restrict(p, x, adjoint, t);
+
+ gettimeofday(&end, NULL);
+ printf("C++ Enzyme combined restrict %0.6f res=%f\n", tdiff(&start, &end), res);
+ }
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+
+ double res;
+ for(int i=0; i<10000; i++)
+ res = rustfoobar_safe(p, x, adjoint, t);
+
+ gettimeofday(&end, NULL);
+ printf("Rust Enzyme combined safe %0.6f res=%f\n", tdiff(&start, &end), res);
+ }
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+
+ double res;
+ for(int i=0; i<10000; i++)
+ res = rustfoobar_unsf(p, x, adjoint, t);
+
+ gettimeofday(&end, NULL);
+ printf("Rust Enzyme combined unsf %0.6f res=%f\n", tdiff(&start, &end), res);
+ }
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ state_type x2;
+
+ for(int i=0; i<10000; i++) {
+ lorenz_norestrict(x, x2, t);
+ }
+
+ gettimeofday(&end, NULL);
+ printf("C++ fwd mayalias %0.6f res=%f\n", tdiff(&start, &end), x2[0]);
+ }
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ state_type x2;
+
+ for(int i=0; i<10000; i++) {
+ lorenz_restrict(x, x2, t);
+ }
+
+ gettimeofday(&end, NULL);
+ printf("C++ fwd restrict %0.6f res=%f\n", tdiff(&start, &end), x2[0]);
+ }
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ state_type x2;
+
+ for(int i=0; i<10000; i++)
+ rust_lorenz_safe(x, x2, t);
+
+ gettimeofday(&end, NULL);
+ printf("Rust fwd safe %0.6f res=%f\n\n", tdiff(&start, &end), x2[0]);
+ }
+
+ {
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ state_type x2;
+
+ for(int i=0; i<10000; i++)
+ rust_lorenz_unsf(x, x2, t);
+
+ gettimeofday(&end, NULL);
+ printf("Rust fwd unsf %0.6f res=%f\n\n", tdiff(&start, &end), x2[0]);
+ }
+
//printf("res=%f\n", foobar(1000));
}
-
-
-#if 0
-
-typedef boost::array< double , 6 > state_type;
-
-void lorenz( const state_type &x , state_type &dxdt , double t )
-{
- // Extract the parameters
- double k1 = x[3];
- double k2 = x[4];
- double k3 = x[5];
-
- dxdt[0] = -k1 * x[0] + k3 * x[1] * x[2];
- dxdt[1] = k1 * x[0] - k2 * x[1] * x[1] - k3 * x[1] * x[2];
- dxdt[2] = k2 * x[1] * x[1];
-
- // Don't change the parameters p
- dxdt[3] = 0;
- dxdt[4] = 0;
- dxdt[5] = 0;
-}
-
-double foobar(double* p, uint64_t iters) {
- state_type x = { 1.0, 0, 0, p[0], p[1], p[2] }; // initial conditions
- double t = 1e5;
- typedef controlled_runge_kutta< runge_kutta_dopri5< state_type , typename state_type::value_type , state_type , double > > stepper_type;
- //typedef euler< state_type , typename state_type::value_type , state_type , double > stepper_type;
- integrate_const( stepper_type(), lorenz , x , 0.0 , t, t/iters );
-
- return x[0];
-}
-
-typedef boost::array< adouble , 6 > astate_type;
-
-void alorenz( const astate_type &x , astate_type &dxdt , adouble t )
-{
- // Extract the parameters
- adouble k1 = x[3];
- adouble k2 = x[4];
- adouble k3 = x[5];
-
- dxdt[0] = -k1 * x[0] + k3 * x[1] * x[2];
- dxdt[1] = k1 * x[0] - k2 * x[1] * x[1] - k3 * x[1] * x[2];
- dxdt[2] = k2 * x[1] * x[1];
-
- // Don't change the parameters p
- dxdt[3] = 0;
- dxdt[4] = 0;
- dxdt[5] = 0;
-}
-
-adouble afoobar(adouble* p, uint64_t iters) {
- astate_type x = { 1.0, 0, 0, p[0], p[1], p[2] }; // initial conditions
- double t = 1e5;
- typedef controlled_runge_kutta< runge_kutta_dopri5< astate_type , typename astate_type::value_type , astate_type , adouble > > stepper_type;
- //typedef euler< astate_type , typename astate_type::value_type , astate_type , adouble > stepper_type;
- integrate_const( stepper_type(), alorenz , x , 0.0 , t, t/iters );
-
- return x[0];
-}
-
-static
-double afoobar_and_gradient(double* p_in, double* dp_out, uint64_t iters) {
- adept::Stack stack;
- adouble x[3] = { p_in[0], p_in[1], p_in[2] };
- stack.new_recording();
- adouble y = afoobar(x, iters);
- y.set_gradient(1.0);
- stack.compute_adjoint();
- for(int i=0; i<3; i++)
- dp_out[i] = x[i].get_gradient();
- return y.value();
-}
-
-static void adept_sincos(uint64_t iters) {
- {
- struct timeval start, end;
- gettimeofday(&start, NULL);
-
- double p[3] = { 0.04,3e7,1e4 };
- double res = foobar(p, iters);
-
- gettimeofday(&end, NULL);
- printf("Adept real %0.6f res=%f\n", tdiff(&start, &end), res);
- }
-
- {
- struct timeval start, end;
- gettimeofday(&start, NULL);
-
- adept::Stack stack;
- adouble p[3] = { 0.04,3e7,1e4 };
- // stack.new_recording();
- adouble resa = afoobar(p, iters);
- double res = resa.value();
-
- gettimeofday(&end, NULL);
- printf("Adept forward %0.6f res=%f\n", tdiff(&start, &end), res);
- }
-
- {
- struct timeval start, end;
- gettimeofday(&start, NULL);
-
- double p[3] = { 0.04,3e7,1e4 };
- double dp[3] = { 0 };
- afoobar_and_gradient(p, dp, iters);
-
- gettimeofday(&end, NULL);
- printf("Adept combined %0.6f res'=%f\n", tdiff(&start, &end), dp[0]);
- }
-}
-
-static void enzyme_sincos(double inp, uint64_t iters) {
-
- {
- struct timeval start, end;
- gettimeofday(&start, NULL);
-
- double p[3] = { 0.04,3e7,1e4 };
- double res = foobar(p, iters);
-
- gettimeofday(&end, NULL);
- printf("Enzyme real %0.6f res=%f\n", tdiff(&start, &end), res);
- }
-
- {
- struct timeval start, end;
- gettimeofday(&start, NULL);
-
- double p[3] = { 0.04,3e7,1e4 };
- double res = foobar(p, iters);
-
- gettimeofday(&end, NULL);
- printf("Enzyme forward %0.6f res=%f\n", tdiff(&start, &end), res);
- }
-
- {
- struct timeval start, end;
- gettimeofday(&start, NULL);
-
- double p[3] = { 0.04,3e7,1e4 };
- double dp[3] = { 0 };
- __enzyme_autodiff<void>(foobar, p, dp, iters);
-
- gettimeofday(&end, NULL);
- printf("Enzyme combined %0.6f res'=%f\n", tdiff(&start, &end), dp[0]);
- }
-}
-
-int main(int argc, char** argv) {
-
- int max_iters = atoi(argv[1]) ;
- double inp = 2.1;
-
- //for(int iters=max_iters/20; iters<=max_iters; iters+=max_iters/20) {
- auto iters = max_iters;
- printf("iters=%d\n", iters);
- adept_sincos(inp, iters);
- enzyme_sincos(inp, iters);
- //}
-}
-#endif
diff --git a/enzyme/benchmarks/ReverseMode/ode-real/src/lib.rs b/enzyme/benchmarks/ReverseMode/ode-real/src/lib.rs
new file mode 100644
index 0000000..4fbc7e7
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ode-real/src/lib.rs
@@ -0,0 +1,100 @@
+#![feature(autodiff)]
+#![feature(slice_as_chunks)]
+#![feature(iter_next_chunk)]
+#![feature(array_ptr_get)]
+#![allow(non_snake_case)]
+#![allow(non_camel_case_types)]
+#![allow(non_upper_case_globals)]
+
+pub mod safe;
+pub mod unsf;
+
+type StateType = [f64; 2 * N * N];
+
+const N: usize = 32;
+
+
+#[no_mangle]
+pub extern "C" fn rust_lorenz_unsf(x: *const StateType, dxdt: *mut StateType, t: f64) {
+    // C ABI shim for `unsf::lorenz`. Unchecked precondition: both pointers reference valid,
+    // non-overlapping `StateType` arrays for the duration of the call.
+    unsafe { unsf::lorenz(&*x, &mut *dxdt, t) };
+}
+
+
+#[no_mangle]
+pub extern "C" fn rust_lorenz_safe(x: *const StateType, dxdt: *mut StateType, t: f64) {
+    // C ABI shim for `safe::lorenz`; unchecked: pointers must be valid, non-overlapping `StateType` arrays.
+    let (state, deriv): (&StateType, &mut StateType) = unsafe { (&*x, &mut *dxdt) };
+    safe::lorenz(state, deriv, t);
+}
+
+#[no_mangle]
+pub extern "C" fn rust_dbrusselator_2d_loop_unsf(adjoint: *mut StateType, x: *const StateType, dx: *mut StateType, p: *const [f64;3], dp: *mut [f64;3], t: f64) {
+    // Scratch targets for the primal d_u / d_v outputs, which are discarded; each array pointer is viewed as two N * N halves.
+    let mut null1 = [0.; 1 * N * N];
+    let mut null2 = [0.; 1 * N * N];
+    let (x1, dx1, dadj1) = (x.cast::<f64>(), dx.cast::<f64>(), adjoint.cast::<f64>());
+    // SAFETY (unchecked): the offsets stay inside the `StateType` arrays the caller must provide.
+    let (x2, dx2, dadj2) = unsafe { (x1.add(N * N), dx1.add(N * N), dadj1.add(N * N)) };
+
+    unsafe {
+        unsf::dbrusselator_2d_loop_unsf(
+            null1.as_mut_ptr(), dadj1, null2.as_mut_ptr(), dadj2,
+            x1, dx1, x2, dx2,
+            p as *mut f64, dp as *mut f64, t,
+        )
+    };
+}
+
+#[no_mangle]
+pub extern "C" fn rust_dbrusselator_2d_loop_safe(adjoint: *mut StateType, x: *const StateType, dx: *mut StateType, p: *const [f64;3], dp: *mut [f64;3], t: f64) {
+    // C ABI shim for the reverse-mode derivative `safe::dbrusselator_2d_loop`:
+    //   1. turn the raw pointers into references,
+    //   2. check that the benchmark's fixed parameters were passed,
+    //   3. split every 2 * N * N state array into its two N * N halves,
+    //   4. run the derivative; primal outputs land in scratch arrays that are dropped.
+    // SAFETY (unchecked): every pointer must reference a valid array of its declared type,
+    // and the mutable ones must not alias anything else for the duration of the call.
+    let (x, dx, adjoint): (&StateType, &mut StateType, &mut StateType) =
+        unsafe { (&*x, &mut *dx, &mut *adjoint) };
+    let (p, dp): (&[f64; 3], &mut [f64; 3]) = unsafe { (&*p, &mut *dp) };
+
+    // These match the values the C++ driver passes (p = {3.4, 1, 10}, t = 2.1).
+    assert!(p[0] == 3.4);
+    assert!(p[1] == 1.);
+    assert!(p[2] == 10.);
+    assert!(t == 2.1);
+
+    // `as_chunks` yields the full N * N chunks plus a remainder. 2 * N * N elements always
+    // give exactly two chunks and an empty remainder, so the `else` arms cannot run.
+    // Alternatives that were considered: transmuting the array, or casting the raw
+    // pointer to a pointer to `[[f64; N * N]; 2]`.
+    // https://discord.com/channels/273534239310479360/273541522815713281/1236945105601040446
+    let ([x1, x2], []): (&[[f64; N * N]], &[f64]) = x.as_chunks() else {
+        unreachable!()
+    };
+    let ([dx1, dx2], []): (&mut [[f64; N * N]], &mut [f64]) = dx.as_chunks_mut() else {
+        unreachable!()
+    };
+    let ([dadj1, dadj2], []): (&mut [[f64; N * N]], &mut [f64]) = adjoint.as_chunks_mut() else {
+        unreachable!()
+    };
+
+    // Receive the primal d_u / d_v values, which nothing reads afterwards.
+    let mut null1 = [0.; 1 * N * N];
+    let mut null2 = [0.; 1 * N * N];
+    safe::dbrusselator_2d_loop(
+        &mut null1,
+        dadj1,
+        &mut null2,
+        dadj2,
+        x1,
+        dx1,
+        x2,
+        dx2,
+        p,
+        dp,
+        t,
+    );
+}
diff --git a/enzyme/benchmarks/ReverseMode/ode-real/src/safe.rs b/enzyme/benchmarks/ReverseMode/ode-real/src/safe.rs
new file mode 100644
index 0000000..ddf3685
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ode-real/src/safe.rs
@@ -0,0 +1,75 @@
+use std::autodiff::autodiff;
+
+const N: usize = 32;
+const xmin: f64 = 0.;
+const xmax: f64 = 1.;
+const ymin: f64 = 0.;
+const ymax: f64 = 1.;
+
+#[inline(always)] // i-th of N_var evenly spaced points spanning min..=max
+fn range(min: f64, max: f64, i: usize, N_var: usize) -> f64 {
+    (i as f64) * ((max - min) / ((N_var as f64) - 1.)) + min
+}
+
+// Forcing term: 5.0 inside the disc of radius 0.1 around (0.3, 0.6) once t >= 1.1, else 0.0.
+fn brusselator_f(x: f64, y: f64, t: f64) -> f64 {
+    let (dx, dy) = (x - 0.3, y - 0.6);
+    let inside_disc = dx * dx + dy * dy <= 0.1 * 0.1;
+    if inside_disc && t >= 1.1 {
+        return 5.0;
+    }
+    0.0
+}
+
+#[expect(unused)]
+fn init_brusselator(u: &mut [f64], v: &mut [f64]) {
+    assert!(u.len() == N * N);
+    assert!(v.len() == N * N);
+    // Row-major N x N grid: u depends only on y (column), v only on x (row).
+    for (cell, (u_out, v_out)) in u.iter_mut().zip(v.iter_mut()).enumerate() {
+        let x = range(xmin, xmax, cell / N, N);
+        let y = range(ymin, ymax, cell % N, N);
+        let (fx, fy) = (x * (1.0 - x), y * (1.0 - y));
+        *u_out = 22.0 * fy * fy.sqrt();
+        *v_out = 27.0 * fx * fx.sqrt();
+    }
+}
+
+#[no_mangle]
+#[autodiff(dbrusselator_2d_loop, Reverse, Duplicated, Duplicated, Duplicated, Duplicated, Duplicated, Const)]
+pub fn brusselator_2d_loop(d_u: &mut [f64;N*N], d_v: &mut [f64;N*N], u: &[f64;N*N], v: &[f64;N*N], p: &[f64;3], t: f64) {
+    // Right-hand side of the 2D Brusselator on an N x N grid; p = [A, B, alpha].
+    let [A, B, alpha] = *p;
+    let dx = 1. / (N - 1) as f64;
+    let alpha = alpha / (dx * dx);
+    for i in 0..N {
+        // Neighbour rows; at the grid edge the neighbour is the edge row itself.
+        let (im1, ip1) = (if i == 0 { i } else { i - 1 }, if i == N - 1 { i } else { i + 1 });
+        for j in 0..N {
+            let x = range(xmin, xmax, i, N);
+            let y = range(ymin, ymax, j, N);
+            let (jm1, jp1) = (if j == 0 { j } else { j - 1 }, if j == N - 1 { j } else { j + 1 });
+            // Flat row-major index of cell (i, j).
+            let c = N * i + j;
+            let u2v = u[c] * u[c] * v[c];
+            d_u[c] = alpha * (u[N * im1 + j] + u[N * ip1 + j] + u[N * i + jp1] + u[N * i + jm1] - 4. * u[c])
+                + B + u2v - (A + 1.) * u[c] + brusselator_f(x, y, t);
+            d_v[c] = alpha * (v[N * im1 + j] + v[N * ip1 + j] + v[N * i + jp1] + v[N * i + jm1] - 4. * v[c])
+                + A * u[c] - u2v;
+        }
+    }
+}
+
+pub type StateType = [f64; 2 * N * N];
+
+// Evaluates the Brusselator right-hand side for state `x` at time `t` into `dxdt` (u half, then v half).
+pub fn lorenz(x: &StateType, dxdt: &mut StateType, t: f64) {
+    let p = [3.4, 1., 10.];
+    // Borrow the halves in place: converting to an owned `[f64; N * N]` copies them, so the
+    // kernel would write into temporaries and `dxdt` would never be updated.
+    let (du, dv) = dxdt.split_at_mut(N * N);
+    let (du, dv): (&mut [f64; N * N], &mut [f64; N * N]) = (du.try_into().unwrap(), dv.try_into().unwrap());
+    let (u, v) = x.split_at(N * N);
+    let (u, v): (&[f64; N * N], &[f64; N * N]) = (u.try_into().unwrap(), v.try_into().unwrap());
+    brusselator_2d_loop(du, dv, u, v, &p, t);
+}
+
diff --git a/enzyme/benchmarks/ReverseMode/ode-real/src/unsf.rs b/enzyme/benchmarks/ReverseMode/ode-real/src/unsf.rs
new file mode 100644
index 0000000..9f1e400
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ode-real/src/unsf.rs
@@ -0,0 +1,79 @@
+use std::autodiff::autodiff;
+
+const N: usize = 32; // grid points per axis; each flattened grid holds N * N values
+const xmin: f64 = 0.; // lower bound passed to `range` for the x coordinate
+const xmax: f64 = 1.; // upper bound passed to `range` for the x coordinate
+const ymin: f64 = 0.; // lower bound passed to `range` for the y coordinate
+const ymax: f64 = 1.; // upper bound passed to `range` for the y coordinate
+
+#[inline(always)] // `i`-th of `N_var` evenly spaced samples covering [min, max]; `N_var == 1` divides by zero
+fn range(min: f64, max: f64, i: usize, N_var: usize) -> f64 {
+    min + i as f64 * ((max - min) / (N_var as f64 - 1.)) // start + index * spacing
+}
+
+/// Forcing term: 5.0 inside the disc of radius 0.1 centred at (0.3, 0.6) once `t >= 1.1`, otherwise 0.0.
+fn brusselator_f(x: f64, y: f64, t: f64) -> f64 {
+    let (off_x, off_y) = (x - 0.3, y - 0.6);
+    let inside_disc = off_x * off_x + off_y * off_y <= 0.1 * 0.1;
+    if !(inside_disc && t >= 1.1) {
+        return 0.0;
+    }
+    5.0
+}
+
+/// Fills the N x N row-major grids behind `u` and `v` with their initial values; both are written at offsets 0..N * N.
+#[expect(unused)]
+unsafe fn init_brusselator(u: *mut f64, v: *mut f64) {
+    for idx in 0..N * N {
+        let (row, col) = (idx / N, idx % N);
+        let x = range(xmin, xmax, row, N);
+        let y = range(ymin, ymax, col, N);
+        *u.add(idx) = 22.0 * (y * (1.0 - y)) * (y * (1.0 - y)).sqrt();
+        *v.add(idx) = 27.0 * (x * (1.0 - x)) * (x * (1.0 - x)).sqrt();
+    }
+}
+
+// Writes the Brusselator right-hand side for each cell of the N x N row-major grids `u`/`v` through `d_u`/`d_v`.
+// Safety: `p` is read at offsets 0..3 as `[A, B, alpha]`; every grid pointer is dereferenced at offsets 0..N * N.
+#[no_mangle]
+#[autodiff(dbrusselator_2d_loop_unsf, Reverse, Duplicated, Duplicated, Duplicated, Duplicated, Duplicated, Const)]
+pub unsafe fn brusselator_2d_loop_unsf(d_u: *mut f64, d_v: *mut f64, u: *const f64, v: *const f64, p: *const f64, t: f64) {
+    let (A, B) = (*p.add(0), *p.add(1));
+    let dx = 1. / (N - 1) as f64;
+    let alpha = *p.add(2) / (dx * dx); // p[2] rescaled by 1 / dx^2
+    for i in 0..N {
+        for j in 0..N {
+            let (x, y) = (range(xmin, xmax, i, N), range(ymin, ymax, j, N));
+            // Clamp neighbour indices so a border cell stands in for its missing neighbour.
+            let (im1, ip1) = (if i == 0 { i } else { i - 1 }, if i == N - 1 { i } else { i + 1 });
+            let (jm1, jp1) = (if j == 0 { j } else { j - 1 }, if j == N - 1 { j } else { j + 1 });
+            let c = N * i + j; // flat index of cell (i, j)
+            // Values are deliberately re-read through the pointers on every use rather than cached, in case buffers alias.
+            let u2v = *u.add(c) * *u.add(c) * *v.add(c);
+            *d_u.add(c) = alpha * (*u.add(N * im1 + j) + *u.add(N * ip1 + j) + *u.add(N * i + jp1) + *u.add(N * i + jm1) - 4. * *u.add(c))
+                + B + u2v - (A + 1.) * *u.add(c) + brusselator_f(x, y, t);
+            *d_v.add(c) = alpha * (*v.add(N * im1 + j) + *v.add(N * ip1 + j) + *v.add(N * i + jp1) + *v.add(N * i + jm1) - 4. * *v.add(c))
+                + A * *u.add(c) - u2v;
+        }
+    }
+}
+
+type StateType = [f64; 2 * N * N]; // packed state: N * N `u` values followed by N * N `v` values
+
+/// Evaluates the Brusselator right-hand side over raw pointers, with `p = [3.4, 1., 10.]`.
+/// The first N * N values of `x` and `dxdt` form the `u` block and the next N * N the `v` block.
+///
+/// # Safety
+/// `x` is read and `dxdt` is written across a whole `StateType`, so both must point to one.
+pub unsafe fn lorenz(x: *const StateType, dxdt: *mut StateType, t: f64) {
+    let p = [3.4, 1., 10.];
+    // View both arrays as flat `f64` buffers so each can be split by pointer offset.
+    let state = x as *const f64;
+    let deriv = dxdt as *mut f64;
+    // An offset of N * N elements is the midpoint of the 2 * N * N element array.
+    let (u, v) = (state, unsafe { state.add(N * N) });
+    let (d_u, d_v) = (deriv, unsafe { deriv.add(N * N) });
+    // The kernel writes the `u` derivative through `d_u` and the `v` derivative through `d_v`.
+    unsafe { brusselator_2d_loop_unsf(d_u, d_v, u, v, p.as_ptr(), t) };
+}
+