Add Rust variants to the ADBench reverse-mode benchmarks (adbench)

Co-authored-by: Lorenz Schmidt <bytesnake@mailbox.org>
Co-authored-by: Jed Brown <jed@jedbrown.org>
Co-authored-by: William Moses <gh@wsmoses.com>
diff --git a/.gitignore b/.gitignore
index 5e7285d..62c76d1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,6 +8,7 @@
 enzyme/benchmarks/ReverseMode/*/*.bc
 enzyme/benchmarks/ReverseMode/*/*.o
 enzyme/benchmarks/ReverseMode/*/*.exe
+enzyme/benchmarks/ReverseMode/*/target/
 enzyme/benchmarks/ReverseMode/*/results.txt
 enzyme/benchmarks/ReverseMode/*/results.json
 .cache
diff --git a/enzyme/benchmarks/ReverseMode/adbench/ba.h b/enzyme/benchmarks/ReverseMode/adbench/ba.h
index 3ade86a..6a3f977 100644
--- a/enzyme/benchmarks/ReverseMode/adbench/ba.h
+++ b/enzyme/benchmarks/ReverseMode/adbench/ba.h
@@ -115,60 +115,68 @@
 };
 
 extern "C" {
-    void ba_objective(
-        int n,
-        int m,
-        int p,
-        double const* cams,
-        double const* X,
-        double const* w,
-        int const* obs,
-        double const* feats,
-        double* reproj_err,
-        double* w_err
-    );
+void ba_objective_restrict(int n, int m, int p, double const *cams,
+                           double const *X, double const *w, int const *obs,
+                           double const *feats, double *reproj_err,
+                           double *w_err);
 
-    void dcompute_reproj_error(
-        double const* cam,
-        double * dcam,
-        double const* X,
-        double * dX,
-        double const* w,
-        double * wb,
-        double const* feat,
-        double *err,
-        double *derr
-    );
+void ba_objective(int n, int m, int p, double const *cams, double const *X,
+                  double const *w, int const *obs, double const *feats,
+                  double *reproj_err, double *w_err);
 
-    void dcompute_zach_weight_error(double const* w, double* dw, double* err, double* derr);
+void rust2_unsafe_ba_objective(int n, int m, int p, double const *cams,
+                               double const *X, double const *w, int const *obs,
+                               double const *feats, double *reproj_err,
+                               double *w_err);
 
-    void compute_reproj_error_b(
-        double const* cam,
-        double * dcam,
-        double const* X,
-        double * dX,
-        double const* w,
-        double * wb,
-        double const* feat,
-        double *err,
-        double *derr
-    );
+void rust2_ba_objective(int n, int m, int p, double const *cams,
+                        double const *X, double const *w, int const *obs,
+                        double const *feats, double *reproj_err, double *w_err);
 
-    void compute_zach_weight_error_b(double const* w, double* dw, double* err, double* derr);
+void dcompute_reproj_error_restrict(double const *cam, double *dcam,
+                                    double const *X, double *dX,
+                                    double const *w, double *wb,
+                                    double const *feat, double *err,
+                                    double *derr);
 
-    void adept_compute_reproj_error(
-        double const* cam,
-        double * dcam,
-        double const* X,
-        double * dX,
-        double const* w,
-        double * wb,
-        double const* feat,
-        double *err,
-        double *derr
-    );
+void dcompute_zach_weight_error_restrict(double const *w, double *dw,
+                                         double *err, double *derr);
 
-    void adept_compute_zach_weight_error(double const* w, double* dw, double* err, double* derr);
+void dcompute_reproj_error(double const *cam, double *dcam, double const *X,
+                           double *dX, double const *w, double *wb,
+                           double const *feat, double *err, double *derr);
+
+void dcompute_zach_weight_error(double const *w, double *dw, double *err,
+                                double *derr);
+
+void compute_reproj_error_b(double const *cam, double *dcam, double const *X,
+                            double *dX, double const *w, double *wb,
+                            double const *feat, double *err, double *derr);
+
+void compute_zach_weight_error_b(double const *w, double *dw, double *err,
+                                 double *derr);
+
+void adept_compute_reproj_error(double const *cam, double *dcam,
+                                double const *X, double *dX, double const *w,
+                                double *wb, double const *feat, double *err,
+                                double *derr);
+
+void adept_compute_zach_weight_error(double const *w, double *dw, double *err,
+                                     double *derr);
+// Rust ports; no unsafe Zach-weight variant exists (the safe one is reused).
+void rust_unsafe_dcompute_reproj_error(double const *cam, double *dcam,
+                                       double const *X, double *dX,
+                                       double const *w, double *wb,
+                                       double const *feat, double *err,
+                                       double *derr);
+
+void rust_dcompute_reproj_error(double const *cam, double *dcam,
+                                double const *X, double *dX, double const *w,
+                                double *wb, double const *feat, double *err,
+                                double *derr);
+
+void rust_dcompute_zach_weight_error(double const *w, double *dw, double *err,
+                                     double *derr);
 }
 
 void read_ba_instance(const string& fn,
@@ -335,10 +343,22 @@
     std::string path = "/mnt/Data/git/Enzyme/apps/ADBench/data/ba/ba1_n49_m7776_p31843.txt";
 
     std::vector<std::string> paths = {
-        "ba10_n1197_m126327_p563734.txt",  "ba14_n356_m226730_p1255268.txt",   "ba18_n1936_m649673_p5213733.txt",    "ba2_n21_m11315_p36455.txt",    "ba6_n539_m65220_p277273.txt",  "test.txt",
-        "ba11_n1723_m156502_p678718.txt",  "ba15_n1102_m780462_p4052340.txt",  "ba19_n4585_m1324582_p9125125.txt",   "ba3_n161_m48126_p182072.txt",  "ba7_n93_m61203_p287451.txt",
-        "ba12_n253_m163691_p899155.txt",   "ba16_n1544_m942409_p4750193.txt",  "ba1_n49_m7776_p31843.txt",           "ba4_n372_m47423_p204472.txt",  "ba8_n88_m64298_p383937.txt",
-        "ba13_n245_m198739_p1091386.txt",  "ba17_n1778_m993923_p5001946.txt",  "ba20_n13682_m4456117_p2987644.txt",  "ba5_n257_m65132_p225911.txt",  "ba9_n810_m88814_p393775.txt",
+        "ba10_n1197_m126327_p563734.txt",
+        "ba14_n356_m226730_p1255268.txt", //  "ba18_n1936_m649673_p5213733.txt",
+                                          //  "ba2_n21_m11315_p36455.txt",
+                                          //  "ba6_n539_m65220_p277273.txt",
+                                          //  "test.txt",
+        //       "ba11_n1723_m156502_p678718.txt",
+        //       "ba15_n1102_m780462_p4052340.txt",
+        //       "ba19_n4585_m1324582_p9125125.txt",
+        //       "ba3_n161_m48126_p182072.txt",  "ba7_n93_m61203_p287451.txt",
+        //       "ba12_n253_m163691_p899155.txt",
+        //       "ba16_n1544_m942409_p4750193.txt",  "ba1_n49_m7776_p31843.txt",
+        //       "ba4_n372_m47423_p204472.txt",  "ba8_n88_m64298_p383937.txt",
+        //       "ba13_n245_m198739_p1091386.txt",
+        //       "ba17_n1778_m993923_p5001946.txt",
+        //       "ba20_n13682_m4456117_p2987644.txt",
+        //       "ba5_n257_m65132_p225911.txt",  "ba9_n810_m88814_p393775.txt",
     };
 
     std::ofstream jsonfile("results.json", std::ofstream::trunc);
@@ -358,27 +378,6 @@
         BASparseMat(input.n, input.m, input.p)
     };
 
-    //BASparseMat(this->input.n, this->input.m, this->input.p)
-
-    /*
-    ba_objective(
-        input.n,
-        input.m,
-        input.p,
-        input.cams.data(),
-        input.X.data(),
-        input.w.data(),
-        input.obs.data(),
-        input.feats.data(),
-        result.reproj_err.data(),
-        result.w_err.data()
-    );
-
-    for(unsigned i=0; i<input.p; i++) {
-        //printf("w_err[%d]=%f reproj_err[%d]=%f, reproj_err[%d]=%f\n", i, result.w_err[i], 2*i, result.reproj_err[2*i], 2*i+1, result.reproj_err[2*i+1]);
-    }
-    */
-
     {
       struct timeval start, end;
       gettimeofday(&start, NULL);
@@ -409,10 +408,198 @@
         BASparseMat(input.n, input.m, input.p)
     };
 
-    //BASparseMat(this->input.n, this->input.m, this->input.p)
+    {
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
+      calculate_jacobian<adept_compute_reproj_error,
+                         adept_compute_zach_weight_error>(input, result);
+      gettimeofday(&end, NULL);
+      printf("Adept combined %0.6f\n", tdiff(&start, &end));
+      json adept;
+      adept["name"] = "Adept combined";
+      adept["runtime"] = tdiff(&start, &end);
+      for (unsigned i = 0; i < 5; i++) {
+        printf("%f ", result.J.vals[i]);
+        adept["result"].push_back(result.J.vals[i]);
+      }
+      printf("\n");
+      test_suite["tools"].push_back(adept);
+    }
+    }
 
-    /*
-    ba_objective(
+    {
+
+    struct BAInput input;
+    read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
+                     input.X, input.w, input.obs, input.feats);
+
+    struct BAOutput result = {std::vector<double>(2 * input.p),
+                              std::vector<double>(input.p),
+                              BASparseMat(input.n, input.m, input.p)};
+
+    {
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
+      calculate_jacobian<dcompute_reproj_error_restrict,
+                         dcompute_zach_weight_error_restrict>(input, result);
+      gettimeofday(&end, NULL);
+      printf("Enzyme restrict c++ combined %0.6f\n", tdiff(&start, &end));
+      json enzyme;
+      enzyme["name"] = "Enzyme restrict c++ combined";
+      enzyme["runtime"] = tdiff(&start, &end);
+      for (unsigned i = 0; i < 5; i++) {
+        printf("%f ", result.J.vals[i]);
+        enzyme["result"].push_back(result.J.vals[i]);
+      }
+      printf("\n");
+      test_suite["tools"].push_back(enzyme);
+    }
+    }
+
+    {
+    // Enzyme, non-restrict ("aliasing") C++ kernels; JSON name == printf label.
+    struct BAInput input;
+    read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
+                     input.X, input.w, input.obs, input.feats);
+
+    struct BAOutput result = {std::vector<double>(2 * input.p),
+                              std::vector<double>(input.p),
+                              BASparseMat(input.n, input.m, input.p)};
+
+    {
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
+      calculate_jacobian<dcompute_reproj_error, dcompute_zach_weight_error>(
+          input, result);
+      gettimeofday(&end, NULL);
+      printf("Enzyme aliasing c++ combined %0.6f\n", tdiff(&start, &end));
+      json enzyme;
+      enzyme["name"] = "Enzyme aliasing c++ combined";
+      enzyme["runtime"] = tdiff(&start, &end);
+      for (unsigned i = 0; i < 5; i++) {
+        printf("%f ", result.J.vals[i]);
+        enzyme["result"].push_back(result.J.vals[i]);
+      }
+      printf("\n");
+      test_suite["tools"].push_back(enzyme);
+    }
+    }
+
+    {
+    struct BAInput input;
+    read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
+                     input.X, input.w, input.obs, input.feats);
+
+    struct BAOutput result = {std::vector<double>(2 * input.p),
+                              std::vector<double>(input.p),
+                              BASparseMat(input.n, input.m, input.p)};
+
+    {
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
+      ba_objective_restrict(input.n, input.m, input.p, input.cams.data(),
+                            input.X.data(), input.w.data(), input.obs.data(),
+                            input.feats.data(), result.reproj_err.data(),
+                            result.w_err.data());
+      gettimeofday(&end, NULL);
+      printf("primal restrict c++ t=%0.6f\n", tdiff(&start, &end));
+      json enzyme;
+      enzyme["name"] = "primal restrict c++";
+      enzyme["runtime"] = tdiff(&start, &end);
+      for (unsigned i = 0; i < 5; i++) {
+        printf("%f ", result.reproj_err[i]);
+        enzyme["result"].push_back(result.reproj_err[i]);
+      }
+      for (unsigned i = 0; i < 5; i++) {
+        printf("%f ", result.w_err[i]);
+        enzyme["result"].push_back(result.w_err[i]);
+      }
+      printf("\n");
+      test_suite["tools"].push_back(enzyme);
+    }
+    }
+
+    {
+    struct BAInput input;
+    read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
+                     input.X, input.w, input.obs, input.feats);
+
+    struct BAOutput result = {std::vector<double>(2 * input.p),
+                              std::vector<double>(input.p),
+                              BASparseMat(input.n, input.m, input.p)};
+
+    {
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
+      ba_objective(input.n, input.m, input.p, input.cams.data(), input.X.data(),
+                   input.w.data(), input.obs.data(), input.feats.data(),
+                   result.reproj_err.data(), result.w_err.data());
+      gettimeofday(&end, NULL);
+      printf("primal aliasing c++ t=%0.6f\n", tdiff(&start, &end));
+      json enzyme;
+      enzyme["name"] = "primal aliasing c++";
+      enzyme["runtime"] = tdiff(&start, &end);
+      for(unsigned i=0; i<5; i++) {
+        printf("%f ", result.reproj_err[i]);
+        enzyme["result"].push_back(result.reproj_err[i]);
+      }
+      for(unsigned i=0; i<5; i++) {
+        printf("%f ", result.w_err[i]);
+        enzyme["result"].push_back(result.w_err[i]);
+      }
+      printf("\n");
+      test_suite["tools"].push_back(enzyme);
+    }
+    }
+
+    {
+    struct BAInput input;
+    read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
+                     input.X, input.w, input.obs, input.feats);
+
+    struct BAOutput result = {std::vector<double>(2 * input.p),
+                              std::vector<double>(input.p),
+                              BASparseMat(input.n, input.m, input.p)};
+    {
+
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
+      rust2_unsafe_ba_objective(input.n, input.m, input.p, input.cams.data(),
+                                input.X.data(), input.w.data(),
+                                input.obs.data(), input.feats.data(),
+                                result.reproj_err.data(), result.w_err.data());
+      gettimeofday(&end, NULL);
+      printf("primal unsafe rust t=%0.6f\n", tdiff(&start, &end));
+      json enzyme;
+      enzyme["name"] = "primal unsafe rust";
+      enzyme["runtime"] = tdiff(&start, &end);
+      for (unsigned i = 0; i < 5; i++) {
+        printf("%f ", result.reproj_err[i]);
+        enzyme["result"].push_back(result.reproj_err[i]);
+      }
+      for (unsigned i = 0; i < 5; i++) {
+        printf("%f ", result.w_err[i]);
+        enzyme["result"].push_back(result.w_err[i]);
+      }
+      printf("\n");
+      test_suite["tools"].push_back(enzyme);
+    }
+    }
+
+    {
+    struct BAInput input;
+    read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams, input.X, input.w, input.obs, input.feats);
+
+    struct BAOutput result = {
+        std::vector<double>(2 * input.p),
+        std::vector<double>(input.p),
+        BASparseMat(input.n, input.m, input.p)
+    };
+    {
+
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
+    rust2_ba_objective(
         input.n,
         input.m,
         input.p,
@@ -424,29 +611,22 @@
         result.reproj_err.data(),
         result.w_err.data()
     );
-
-    for(unsigned i=0; i<input.p; i++) {
-        //printf("w_err[%d]=%f reproj_err[%d]=%f, reproj_err[%d]=%f\n", i, result.w_err[i], 2*i, result.reproj_err[2*i], 2*i+1, result.reproj_err[2*i+1]);
-    }
-    */
-
-    {
-      struct timeval start, end;
-      gettimeofday(&start, NULL);
-      calculate_jacobian<adept_compute_reproj_error, adept_compute_zach_weight_error>(input, result);
       gettimeofday(&end, NULL);
-      printf("Adept combined %0.6f\n", tdiff(&start, &end));
-      json adept;
-      adept["name"] = "Adept combined";
-      adept["runtime"] = tdiff(&start, &end);
+      printf("primal rust t=%0.6f\n", tdiff(&start, &end));
+      json enzyme;
+      enzyme["name"] = "primal rust";
+      enzyme["runtime"] = tdiff(&start, &end);
       for(unsigned i=0; i<5; i++) {
-        printf("%f ", result.J.vals[i]);
-        adept["result"].push_back(result.J.vals[i]);
+        printf("%f ", result.reproj_err[i]);
+        enzyme["result"].push_back(result.reproj_err[i]);
+      }
+      for(unsigned i=0; i<5; i++) {
+        printf("%f ", result.w_err[i]);
+        enzyme["result"].push_back(result.w_err[i]);
       }
       printf("\n");
-      test_suite["tools"].push_back(adept);
+      test_suite["tools"].push_back(enzyme);
     }
-
     }
 
     {
@@ -460,35 +640,43 @@
         BASparseMat(input.n, input.m, input.p)
     };
 
-    //BASparseMat(this->input.n, this->input.m, this->input.p)
-
-    /*
-    ba_objective(
-        input.n,
-        input.m,
-        input.p,
-        input.cams.data(),
-        input.X.data(),
-        input.w.data(),
-        input.obs.data(),
-        input.feats.data(),
-        result.reproj_err.data(),
-        result.w_err.data()
-    );
-
-    for(unsigned i=0; i<input.p; i++) {
-        //printf("w_err[%d]=%f reproj_err[%d]=%f, reproj_err[%d]=%f\n", i, result.w_err[i], 2*i, result.reproj_err[2*i], 2*i+1, result.reproj_err[2*i+1]);
+    {
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
+      calculate_jacobian<rust_unsafe_dcompute_reproj_error,
+                         rust_dcompute_zach_weight_error>(input, result);
+      gettimeofday(&end, NULL);
+      printf("Enzyme unsafe rust combined %0.6f\n", tdiff(&start, &end));
+      json enzyme;
+      enzyme["name"] = "Enzyme unsafe rust combined";
+      enzyme["runtime"] = tdiff(&start, &end);
+      for (unsigned i = 0; i < 5; i++) {
+        printf("%f ", result.J.vals[i]);
+        enzyme["result"].push_back(result.J.vals[i]);
+      }
+      printf("\n");
+      test_suite["tools"].push_back(enzyme);
     }
-    */
+    }
+
+    {
+
+    struct BAInput input;
+    read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
+                     input.X, input.w, input.obs, input.feats);
+
+    struct BAOutput result = {std::vector<double>(2 * input.p),
+                              std::vector<double>(input.p),
+                              BASparseMat(input.n, input.m, input.p)};
 
     {
       struct timeval start, end;
       gettimeofday(&start, NULL);
-      calculate_jacobian<dcompute_reproj_error, dcompute_zach_weight_error>(input, result);
+      calculate_jacobian<rust_dcompute_reproj_error, rust_dcompute_zach_weight_error>(input, result);
       gettimeofday(&end, NULL);
-      printf("Enzyme combined %0.6f\n", tdiff(&start, &end));
+      printf("Enzyme rust combined %0.6f\n", tdiff(&start, &end));
       json enzyme;
-      enzyme["name"] = "Enzyme combined";
+      enzyme["name"] = "Enzyme rust combined";
       enzyme["runtime"] = tdiff(&start, &end);
       for(unsigned i=0; i<5; i++) {
         printf("%f ", result.J.vals[i]);
@@ -497,8 +685,8 @@
       printf("\n");
       test_suite["tools"].push_back(enzyme);
     }
-
     }
+
     test_suite["llvm-version"] = __clang_version__;
     test_suite["mode"] = "ReverseMode";
     test_suite["batch-size"] = 1;
diff --git a/enzyme/benchmarks/ReverseMode/adbench/gmm.h b/enzyme/benchmarks/ReverseMode/adbench/gmm.h
index feef3a7..c5ec727 100644
--- a/enzyme/benchmarks/ReverseMode/adbench/gmm.h
+++ b/enzyme/benchmarks/ReverseMode/adbench/gmm.h
@@ -18,7 +18,7 @@
 using json = nlohmann::json;
 
 struct GMMInput {
-    int d, k, n;
+    size_t d, k, n;  // sizes: alphas k, means d*k, icf k*d*(d+1)/2, x d*n
     std::vector<double> alphas, means, icf, x;
     Wishart wishart;
 };
@@ -33,24 +33,54 @@
 };
 
 extern "C" {
-    void dgmm_objective(int d, int k, int n, const double *alphas, double *
-            alphasb, const double *means, double *meansb, const double *icf,
-            double *icfb, const double *x, Wishart wishart, double *err, double *
-            errb);
+void gmm_objective(size_t d, size_t k, size_t n, double const *alphas,
+                   double const *means, double const *icf, double const *x,
+                   Wishart wishart, double *err);
+void gmm_objective_restrict(size_t d, size_t k, size_t n, double const *alphas,
+                            double const *means, double const *icf,
+                            double const *x, Wishart wishart, double *err);
+void dgmm_objective_restrict(size_t d, size_t k, size_t n, const double *alphas,
+                             double *alphasb, const double *means,
+                             double *meansb, const double *icf, double *icfb,
+                             const double *x, Wishart wishart, double *err,
+                             double *errb);
+void dgmm_objective(size_t d, size_t k, size_t n, const double *alphas, double *alphasb,
+                    const double *means, double *meansb, const double *icf,
+                    double *icfb, const double *x, Wishart wishart, double *err,
+                    double *errb);
 
-    void gmm_objective_b(int d, int k, int n, const double *alphas, double *
-        alphasb, const double *means, double *meansb, const double *icf,
-        double *icfb, const double *x, Wishart wishart, double *err, double *
-        errb);
+void gmm_objective_b(size_t d, size_t k, size_t n, const double *alphas, double *alphasb,
+                     const double *means, double *meansb, const double *icf,
+                     double *icfb, const double *x, Wishart wishart,
+                     double *err, double *errb);
 
-    void adept_dgmm_objective(int d, int k, int n, const double *alphas, double *
-        alphasb, const double *means, double *meansb, const double *icf,
-        double *icfb, const double *x, Wishart wishart, double *err, double *
-        errb);
+void adept_dgmm_objective(size_t d, size_t k, size_t n, const double *alphas,
+                          double *alphasb, const double *means, double *meansb,
+                          const double *icf, double *icfb, const double *x,
+                          Wishart wishart, double *err, double *errb);
+// NOTE(review): rust_* take Wishart& (others by value) - confirm vs Rust side.
+void rust_unsafe_dgmm_objective(size_t d, size_t k, size_t n, const double *alphas,
+                                double *alphasb, const double *means,
+                                double *meansb, const double *icf, double *icfb,
+                                const double *x, Wishart &wishart, double *err,
+                                double *errb);
+
+void rust_unsafe_gmm_objective(size_t d, size_t k, size_t n, const double *alphas,
+                               const double *means, const double *icf,
+                               const double *x, Wishart &wishart, double *err);
+
+void rust_dgmm_objective(size_t d, size_t k, size_t n, const double *alphas,
+                         double *alphasb, const double *means, double *meansb,
+                         const double *icf, double *icfb, const double *x,
+                         Wishart &wishart, double *err, double *errb);
+
+void rust_gmm_objective(size_t d, size_t k, size_t n, const double *alphas,
+                        const double *means, const double *icf, const double *x,
+                        Wishart &wishart, double *err);
 }
 
 void read_gmm_instance(const string& fn,
-    int* d, int* k, int* n,
+    size_t* d, size_t* k, size_t* n,
     vector<double>& alphas,
     vector<double>& means,
     vector<double>& icf,
@@ -65,32 +95,32 @@
         exit(1);
     }
 
-    fscanf(fid, "%i %i %i", d, k, n);
+    fscanf(fid, "%zu %zu %zu", d, k, n);
 
-    int d_ = *d, k_ = *k, n_ = *n;
+    size_t d_ = *d, k_ = *k, n_ = *n;
 
-    int icf_sz = d_ * (d_ + 1) / 2;
+    size_t icf_sz = d_ * (d_ + 1) / 2;
     alphas.resize(k_);
     means.resize(d_ * k_);
     icf.resize(icf_sz * k_);
     x.resize(d_ * n_);
 
-    for (int i = 0; i < k_; i++)
+    for (size_t i = 0; i < k_; i++)
     {
         fscanf(fid, "%lf", &alphas[i]);
     }
 
-    for (int i = 0; i < k_; i++)
+    for (size_t i = 0; i < k_; i++)
     {
-        for (int j = 0; j < d_; j++)
+        for (size_t j = 0; j < d_; j++)
         {
             fscanf(fid, "%lf", &means[i * d_ + j]);
         }
     }
 
-    for (int i = 0; i < k_; i++)
+    for (size_t i = 0; i < k_; i++)
     {
-        for (int j = 0; j < icf_sz; j++)
+        for (size_t j = 0; j < icf_sz; j++)
         {
             fscanf(fid, "%lf", &icf[i * icf_sz + j]);
         }
@@ -98,20 +128,20 @@
 
     if (replicate_point)
     {
-        for (int j = 0; j < d_; j++)
+        for (size_t j = 0; j < d_; j++)
         {
             fscanf(fid, "%lf", &x[j]);
         }
-        for (int i = 0; i < n_; i++)
+        for (size_t i = 0; i < n_; i++)
         {
             memcpy(&x[i * d_], &x[0], d_ * sizeof(double));
         }
     }
     else
     {
-        for (int i = 0; i < n_; i++)
+        for (size_t i = 0; i < n_; i++)
         {
-            for (int j = 0; j < d_; j++)
+            for (size_t j = 0; j < d_; j++)
             {
                 fscanf(fid, "%lf", &x[i * d_ + j]);
             }
@@ -123,10 +153,7 @@
     fclose(fid);
 }
 
-typedef void(*deriv_t)(int d, int k, int n, const double *alphas, double *alphasb, const double *means, double *meansb, const double *icf,
-            double *icfb, const double *x, Wishart wishart, double *err, double *errb);
-
-template<deriv_t deriv>
+template<auto  deriv>
 void calculate_jacobian(struct GMMInput &input, struct GMMOutput &result)
 {
     double* alphas_gradient_part = result.gradient.data();
@@ -159,13 +186,38 @@
     );
 }
 
+// Runs a primal GMM objective (e.g. gmm_objective or rust_gmm_objective)
+// on `input` and returns the scalar error it writes through its final
+// `double *err` argument. No derivatives are computed here.
+template<auto objective>
+double primal(struct GMMInput &input)
+{
+    double err = 0.0;
+    objective(
+        input.d, input.k, input.n,
+        input.alphas.data(),
+        input.means.data(),
+        input.icf.data(),
+        input.x.data(),
+        input.wishart,
+        &err
+    );
+    return err;
+}
+
 int main(const int argc, const char* argv[]) {
     printf("starting main\n");
 
     const auto replicate_point = (argc > 9 && string(argv[9]) == "-rep");
     const GMMParameters params = { replicate_point };
 
-    std::vector<std::string> paths;// = { "1k/gmm_d10_K100.txt" };
+    std::vector<std::string> paths = { "10k/gmm_d10_K200.txt" };
+
+    //getTests(paths, "data/1k", "1k/");
+    if (std::getenv("BENCH_LARGE")) {
+      getTests(paths, "data/2.5k", "2.5k/");
+      getTests(paths, "data/10k", "10k/");
+    }
 
     getTests(paths, "data/1k", "1k/");
     if (std::getenv("BENCH_LARGE")) {
@@ -188,7 +240,7 @@
     read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
         input.alphas, input.means, input.icf, input.x, input.wishart, params.replicate_point);
 
-    int Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
+    size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
 
     struct GMMOutput result = { 0, std::vector<double>(Jcols) };
 
@@ -218,49 +270,82 @@
     read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
         input.alphas, input.means, input.icf, input.x, input.wishart, params.replicate_point);
 
-    int Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
+    size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
 
     struct GMMOutput result = { 0, std::vector<double>(Jcols) };
 
-    try {
-      struct timeval start, end;
-      gettimeofday(&start, NULL);
-      calculate_jacobian<adept_dgmm_objective>(input, result);
-      gettimeofday(&end, NULL);
-      printf("Adept combined %0.6f\n", tdiff(&start, &end));
-      json adept;
-      adept["name"] = "Adept combined";
-      adept["runtime"] = tdiff(&start, &end);
-      for (unsigned i = result.gradient.size() - 5;
-           i < result.gradient.size(); i++) {
-        printf("%f ", result.gradient[i]);
-        adept["result"].push_back(result.gradient[i]);
+    //if (0) {
+      try {
+        struct timeval start, end;
+        gettimeofday(&start, NULL);
+        calculate_jacobian<adept_dgmm_objective>(input, result);
+        gettimeofday(&end, NULL);
+        printf("Adept combined %0.6f\n", tdiff(&start, &end));
+        json adept;
+        adept["name"] = "Adept combined";
+        adept["runtime"] = tdiff(&start, &end);
+        for (unsigned i = result.gradient.size() - 5;
+             i < result.gradient.size(); i++) {
+          printf("%f ", result.gradient[i]);
+          adept["result"].push_back(result.gradient[i]);
+        }
+        printf("\n");
+        test_suite["tools"].push_back(adept);
+      } catch (std::bad_alloc) {
+        printf("Adept combined 88888888 ooms\n");
       }
-      printf("\n");
-      test_suite["tools"].push_back(adept);
-    } catch(std::bad_alloc) {
-       printf("Adept combined 88888888 ooms\n");
+    //}
     }
 
-    }
-
+    for (size_t i = 0; i < 5; i++)
     {
 
     struct GMMInput input;
     read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
         input.alphas, input.means, input.icf, input.x, input.wishart, params.replicate_point);
 
-    int Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
+    size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
 
     struct GMMOutput result = { 0, std::vector<double>(Jcols) };
 
     {
       struct timeval start, end;
       gettimeofday(&start, NULL);
+      calculate_jacobian<dgmm_objective_restrict>(input, result);
+      gettimeofday(&end, NULL);
+      printf("Enzyme c++ restrict combined %0.6f\n", tdiff(&start, &end));
+      json enzyme;
+      enzyme["name"] = "Enzyme restrict combined";
+      enzyme["runtime"] = tdiff(&start, &end);
+      for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
+           i++) {
+        printf("%f ", result.gradient[i]);
+        enzyme["result"].push_back(result.gradient[i]);
+      }
+      printf("\n");
+      test_suite["tools"].push_back(enzyme);
+    }
+    }
+
+    {
+
+    struct GMMInput input;
+    read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
+                      input.alphas, input.means, input.icf, input.x,
+                      input.wishart, params.replicate_point);
+
+    size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
+
+    struct GMMOutput result = {0, std::vector<double>(Jcols)};
+
+    {
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
       calculate_jacobian<dgmm_objective>(input, result);
       gettimeofday(&end, NULL);
+      printf("Enzyme c++ mayalias combined %0.6f\n", tdiff(&start, &end));
       json enzyme;
-      enzyme["name"] = "Enzyme combined";
+      enzyme["name"] = "Enzyme mayalias combined";
       enzyme["runtime"] = tdiff(&start, &end);
       for (unsigned i = result.gradient.size() - 5;
            i < result.gradient.size(); i++) {
@@ -270,8 +355,132 @@
       printf("\n");
       test_suite["tools"].push_back(enzyme);
     }
-
     }
+    
+    {
+
+    struct GMMInput input;
+    read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
+                      input.alphas, input.means, input.icf, input.x,
+                      input.wishart, params.replicate_point);
+
+    size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
+
+    struct GMMOutput result = {0, std::vector<double>(Jcols)};
+    {
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
+      calculate_jacobian<rust_unsafe_dgmm_objective>(input, result);
+      gettimeofday(&end, NULL);
+      printf("Enzyme unsafe rust combined %0.6f\n", tdiff(&start, &end));
+      json enzyme;
+      enzyme["name"] = "Rust unsafe Enzyme combined";
+      enzyme["runtime"] = tdiff(&start, &end);
+      for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
+           i++) {
+        printf("%f ", result.gradient[i]);
+        enzyme["result"].push_back(result.gradient[i]);
+      }
+      printf("\n");
+      test_suite["tools"].push_back(enzyme);
+    }
+    }
+
+    for (size_t i = 0; i < 5; i++)
+    {
+
+    struct GMMInput input;
+    read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
+                      input.alphas, input.means, input.icf, input.x,
+                      input.wishart, params.replicate_point);
+
+    size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
+
+    struct GMMOutput result = {0, std::vector<double>(Jcols)};
+    {
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
+      calculate_jacobian<rust_dgmm_objective>(input, result);
+      gettimeofday(&end, NULL);
+      printf("Enzyme rust combined %0.6f\n", tdiff(&start, &end));
+      json enzyme;
+      enzyme["name"] = "Rust Enzyme combined";
+      enzyme["runtime"] = tdiff(&start, &end);
+      for (unsigned i = result.gradient.size() - 5;
+           i < result.gradient.size(); i++) {
+        printf("%f ", result.gradient[i]);
+        enzyme["result"].push_back(result.gradient[i]);
+      }
+      printf("\n");
+      test_suite["tools"].push_back(enzyme);
+    }
+    }
+
+    {
+
+    struct GMMInput input;
+    read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
+        input.alphas, input.means, input.icf, input.x, input.wishart, params.replicate_point);
+
+    size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;
+
+    struct GMMOutput result = { 0, std::vector<double>(Jcols) };
+
+    {
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
+      auto res = primal<gmm_objective>(input);
+      gettimeofday(&end, NULL);
+      printf("c++ primal mayalias combined t=%0.6f, err=%f\n",
+             tdiff(&start, &end), res);
+
+      json primal;
+      primal["name"] = "C++ primal mayalias";
+      primal["runtime"] = tdiff(&start, &end);
+      primal["result"].push_back(res);
+      test_suite["tools"].push_back(primal);
+    }
+    {
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
+      auto res = primal<gmm_objective_restrict>(input);
+      gettimeofday(&end, NULL);
+      printf("c++ primal restrict combined t=%0.6f, err=%f\n",
+             tdiff(&start, &end), res);
+
+      json primal;
+      primal["name"] = "C++ primal restrict";
+      primal["runtime"] = tdiff(&start, &end);
+      primal["result"].push_back(res);
+      test_suite["tools"].push_back(primal);
+    }
+    {
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
+      auto res = primal<rust_unsafe_gmm_objective>(input);
+      gettimeofday(&end, NULL);
+      printf("rust unsafe primal combined t=%0.6f, err=%f\n",
+             tdiff(&start, &end), res);
+      json primal;
+      primal["name"] = "Rust unsafe primal";
+      primal["runtime"] = tdiff(&start, &end);
+      primal["result"].push_back(res);
+      test_suite["tools"].push_back(primal);
+    }
+    {
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
+      auto res = primal<rust_gmm_objective>(input);
+      gettimeofday(&end, NULL);
+      printf("rust primal combined t=%0.6f, err=%f\n", tdiff(&start, &end), res);
+      json primal;
+      primal["name"] = "Rust primal";
+      primal["runtime"] = tdiff(&start, &end);
+      primal["result"].push_back(res);
+      test_suite["tools"].push_back(primal);
+    }
+    }
+
     test_suite["llvm-version"] = __clang_version__;
     test_suite["mode"] = "ReverseMode";
     test_suite["batch-size"] = 1;
diff --git a/enzyme/benchmarks/ReverseMode/adbench/lstm.h b/enzyme/benchmarks/ReverseMode/adbench/lstm.h
index e6d1330..4f99841 100644
--- a/enzyme/benchmarks/ReverseMode/adbench/lstm.h
+++ b/enzyme/benchmarks/ReverseMode/adbench/lstm.h
@@ -34,37 +34,56 @@
 };
 
 extern "C" {
-    void dlstm_objective(
-        int l,
-        int c,
-        int b,
-        double const* main_params,
-        double* dmain_params,
-        double const* extra_params,
-        double* dextra_params,
-        double* state,
-        double const* sequence,
-        double* loss,
-        double* dloss
-    );
+void rust_unsafe_dlstm_objective(int l, int c, int b, double const *main_params,
+                                 double *dmain_params,
+                                 double const *extra_params,
+                                 double *dextra_params, double *state,
+                                 double const *sequence, double *loss,
+                                 double *dloss);
 
-    void lstm_objective_b(int l, int c, int b, const double *main_params, double *
-        main_paramsb, const double *extra_params, double *extra_paramsb,
-        double *state, const double *sequence, double *loss, double *lossb);
+void rust_unsafe_lstm_objective(int l, int c, int b, double const *main_params,
+                                double const *extra_params, double *state,
+                                double const *sequence, double *loss);
 
-    void adept_dlstm_objective(
-        int l,
-        int c,
-        int b,
-        double const* main_params,
-        double* dmain_params,
-        double const* extra_params,
-        double* dextra_params,
-        double* state,
-        double const* sequence,
-        double* loss,
-        double* dloss
-    );
+void rust_safe_lstm_objective(int l, int c, int b, double const *main_params,
+                              double const *extra_params, double *state,
+                              double const *sequence, double *loss);
+// C++ primal objective, restrict-qualified variant ("C++ restrict primal").
+void cxx_restrict_lstm_objective(int l, int c, int b, double const *main_params,
+                                 double const *extra_params, double *state,
+                                 double const *sequence, double *loss);
+// C++ primal objective, may-alias (no restrict) variant ("C++ mayalias primal").
+void cxx_mayalias_lstm_objective(int l, int c, int b, double const *main_params,
+                                 double const *extra_params, double *state,
+                                 double const *sequence, double *loss);
+// Enzyme gradient of the safe-Rust objective ("Enzyme (safe Rust) combined").
+void rust_safe_dlstm_objective(int l, int c, int b, double const *main_params,
+                               double *dmain_params, double const *extra_params,
+                               double *dextra_params, double *state,
+                               double const *sequence, double *loss,
+                               double *dloss);
+// Enzyme gradient of the C++ may-alias objective ("Enzyme mayalias combined").
+void dlstm_objective_mayalias(int l, int c, int b, double const *main_params,
+                              double *dmain_params, double const *extra_params,
+                              double *dextra_params, double *state,
+                              double const *sequence, double *loss,
+                              double *dloss);
+// Enzyme gradient of the C++ restrict objective ("Enzyme restrict combined").
+void dlstm_objective_restrict(int l, int c, int b, double const *main_params,
+                              double *dmain_params, double const *extra_params,
+                              double *dextra_params, double *state,
+                              double const *sequence, double *loss,
+                              double *dloss);
+// Reverse-mode gradient, presumably Tapenade-generated (`_b` naming) - confirm.
+void lstm_objective_b(int l, int c, int b, const double *main_params,
+                      double *main_paramsb, const double *extra_params,
+                      double *extra_paramsb, double *state,
+                      const double *sequence, double *loss, double *lossb);
+// Gradient via the Adept AD library (going by the name) - confirm.
+void adept_dlstm_objective(int l, int c, int b, double const *main_params,
+                           double *dmain_params, double const *extra_params,
+                           double *dextra_params, double *state,
+                           double const *sequence, double *loss, double *dloss);
 }
 
 void read_lstm_instance(const string& fn,
@@ -177,10 +196,55 @@
     }
 }
 
+// Shared timing loop behind the calculate_*_primal wrappers below: calls the
+// primal LSTM objective `Objective` `iters` times (default 100) on the buffers
+// held in `input` and returns the loss written by the final call. The four
+// wrappers keep their original names and signatures and differ only in which
+// extern "C" objective they time.
+//
+// NOTE(review): every objective receives input.state.data() through a
+// non-const `double *state`, so a call may overwrite input.state; if it does,
+// iterations after the first start from the updated state, not the file's.
+// main() re-reads the instance before each timing block, so all variants
+// still begin from identical input. TODO confirm whether the objectives write
+// to `state`.
+template <void (*Objective)(int, int, int, double const *, double const *,
+                            double *, double const *, double *)>
+double calculate_primal_loop(struct LSTMInput &input, int iters = 100) {
+    double loss = 0.0;
+    for (int i = 0; i < iters; i++) {
+        Objective(input.l, input.c, input.b, input.main_params.data(),
+                  input.extra_params.data(), input.state.data(),
+                  input.sequence.data(), &loss);
+    }
+    return loss;
+}
+
+// C++ objective, may-alias (no __restrict) build.
+double calculate_mayalias_primal(struct LSTMInput &input) {
+    return calculate_primal_loop<cxx_mayalias_lstm_objective>(input);
+}
+
+// C++ objective, __restrict-qualified build.
+double calculate_restrict_primal(struct LSTMInput &input) {
+    return calculate_primal_loop<cxx_restrict_lstm_objective>(input);
+}
+
+// Rust objective, rust_unsafe_* implementation.
+double calculate_unsafe_primal(struct LSTMInput &input) {
+    return calculate_primal_loop<rust_unsafe_lstm_objective>(input);
+}
+
+// Rust objective, rust_safe_* implementation.
+double calculate_safe_primal(struct LSTMInput &input) {
+    return calculate_primal_loop<rust_safe_lstm_objective>(input);
+}
+
 int main(const int argc, const char* argv[]) {
     printf("starting main\n");
 
-    std::vector<std::string> paths = { "lstm_l2_c1024.txt", "lstm_l4_c1024.txt", "lstm_l2_c4096.txt", "lstm_l4_c4096.txt" };
+    //std::vector<std::string> paths = { "lstm_l2_c1024.txt", "lstm_l4_c1024.txt", "lstm_l2_c4096.txt", "lstm_l4_c4096.txt" };
+    std::vector<std::string> paths = { "lstm_l4_c4096.txt" };
     
     std::ofstream jsonfile("results.json", std::ofstream::trunc);
     json test_results;
@@ -227,16 +291,17 @@
 
     {
 
-    struct LSTMInput input = {};
+     struct LSTMInput input = {};
 
     // Read instance
-    read_lstm_instance("data/" + path, &input.l, &input.c, &input.b, input.main_params, input.extra_params, input.state,
-                       input.sequence);
+     read_lstm_instance("data/" + path, &input.l, &input.c, &input.b,
+     input.main_params, input.extra_params, input.state,
+                        input.sequence);
 
-    std::vector<double> state = std::vector<double>(input.state.size());
+     std::vector<double> state = std::vector<double>(input.state.size());
 
-    int Jcols = 8 * input.l * input.b + 3 * input.b;
-    struct LSTMOutput result = { 0, std::vector<double>(Jcols) };
+     int Jcols = 8 * input.l * input.b + 3 * input.b;
+     struct LSTMOutput result = { 0, std::vector<double>(Jcols) };
 
     {
       struct timeval start, end;
@@ -274,14 +339,81 @@
     {
       struct timeval start, end;
       gettimeofday(&start, NULL);
-      calculate_jacobian<dlstm_objective>(input, result);
+      calculate_jacobian<dlstm_objective_restrict>(input, result);
       gettimeofday(&end, NULL);
-      printf("Enzyme combined %0.6f\n", tdiff(&start, &end));
+      printf("Enzyme restrict combined %0.6f\n", tdiff(&start, &end));
       json enzyme;
-       enzyme["name"] = "Enzyme combined";
-       enzyme["runtime"] = tdiff(&start, &end);
-       for (unsigned i = result.gradient.size() - 5;
-            i < result.gradient.size(); i++) {
+      enzyme["name"] = "Enzyme restrict combined";
+      enzyme["runtime"] = tdiff(&start, &end);
+      for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
+           i++) {
+        printf("%f ", result.gradient[i]);
+        enzyme["result"].push_back(result.gradient[i]);
+      }
+      test_suite["tools"].push_back(enzyme);
+
+      printf("\n");
+    }
+    }
+
+    {
+
+    struct LSTMInput input = {};
+
+    // Read instance
+    read_lstm_instance("data/" + path, &input.l, &input.c, &input.b,
+                       input.main_params, input.extra_params, input.state,
+                       input.sequence);
+
+    std::vector<double> state = std::vector<double>(input.state.size());
+
+    int Jcols = 8 * input.l * input.b + 3 * input.b;
+    struct LSTMOutput result = {0, std::vector<double>(Jcols)};
+
+    {
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
+      calculate_jacobian<dlstm_objective_mayalias>(input, result);
+      gettimeofday(&end, NULL);
+      printf("Enzyme mayalias combined %0.6f\n", tdiff(&start, &end));
+      json enzyme;
+      enzyme["name"] = "Enzyme mayalias combined";
+      enzyme["runtime"] = tdiff(&start, &end);
+      for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
+           i++) {
+        printf("%f ", result.gradient[i]);
+        enzyme["result"].push_back(result.gradient[i]);
+       }
+       test_suite["tools"].push_back(enzyme);
+       
+       printf("\n");
+    }
+    }
+
+    {
+
+    struct LSTMInput input = {};
+
+    // Read instance
+    read_lstm_instance("data/" + path, &input.l, &input.c, &input.b, input.main_params, input.extra_params, input.state,
+                       input.sequence);
+
+    std::vector<double> state = std::vector<double>(input.state.size());
+
+    int Jcols = 8 * input.l * input.b + 3 * input.b;
+    struct LSTMOutput result = { 0, std::vector<double>(Jcols) };
+
+    {
+      struct timeval start, end;
+      gettimeofday(&start, NULL);
+      calculate_jacobian<rust_safe_dlstm_objective>(input, result);
+      gettimeofday(&end, NULL);
+      printf("Enzyme (safe Rust) combined %0.6f\n", tdiff(&start, &end));
+      json enzyme;
+      enzyme["name"] = "Enzyme (safe Rust) combined";
+      enzyme["runtime"] = tdiff(&start, &end);
+      for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
+           i++) {
          printf("%f ", result.gradient[i]);
          enzyme["result"].push_back(result.gradient[i]);
        }
@@ -291,6 +423,161 @@
     }
 
     }
+
+    {
+
+    struct LSTMInput input = {};
+
+    // Read instance
+    read_lstm_instance("data/" + path, &input.l, &input.c, &input.b,
+                       input.main_params, input.extra_params, input.state,
+                       input.sequence);
+
+    std::vector<double> state = std::vector<double>(input.state.size());
+
+    int Jcols = 8 * input.l * input.b + 3 * input.b;
+    struct LSTMOutput result = {0, std::vector<double>(Jcols)};
+
+    {
+       struct timeval start, end;
+       gettimeofday(&start, NULL);
+       calculate_jacobian<rust_unsafe_dlstm_objective>(input, result);
+       gettimeofday(&end, NULL);
+       printf("Enzyme (unsafe Rust) combined %0.6f\n", tdiff(&start, &end));
+       json enzyme;
+       enzyme["name"] = "Enzyme (unsafe Rust) combined";
+       enzyme["runtime"] = tdiff(&start, &end);
+       for (unsigned i = result.gradient.size() - 5; i < result.gradient.size();
+            i++) {
+         printf("%f ", result.gradient[i]);
+         enzyme["result"].push_back(result.gradient[i]);
+       }
+       test_suite["tools"].push_back(enzyme);
+
+       printf("\n");
+    }
+    }
+    {
+
+    struct LSTMInput input = {};
+
+    // Read instance
+    read_lstm_instance("data/" + path, &input.l, &input.c, &input.b,
+                       input.main_params, input.extra_params, input.state,
+                       input.sequence);
+
+    std::vector<double> state = std::vector<double>(input.state.size());
+
+    int Jcols = 8 * input.l * input.b + 3 * input.b;
+    struct LSTMOutput result = {0, std::vector<double>(Jcols)};
+
+    {
+       struct timeval start, end;
+       gettimeofday(&start, NULL);
+       double res = calculate_mayalias_primal(input);
+       gettimeofday(&end, NULL);
+       printf("C++ mayalias primal %0.6f\n", tdiff(&start, &end));
+       json enzyme;
+       enzyme["name"] = "C++ mayalias primal";
+       enzyme["runtime"] = tdiff(&start, &end);
+       printf("%f ", res);
+       enzyme["result"].push_back(res);
+       test_suite["tools"].push_back(enzyme);
+
+       printf("\n");
+    }
+    }
+    {
+
+    struct LSTMInput input = {};
+
+    // Read instance
+    read_lstm_instance("data/" + path, &input.l, &input.c, &input.b,
+                       input.main_params, input.extra_params, input.state,
+                       input.sequence);
+
+    std::vector<double> state = std::vector<double>(input.state.size());
+
+    int Jcols = 8 * input.l * input.b + 3 * input.b;
+    struct LSTMOutput result = {0, std::vector<double>(Jcols)};
+
+    {
+       struct timeval start, end;
+       gettimeofday(&start, NULL);
+       double res = calculate_restrict_primal(input);
+       gettimeofday(&end, NULL);
+       printf("C++ restrict primal %0.6f\n", tdiff(&start, &end));
+       json enzyme;
+       enzyme["name"] = "C++ restrict primal";
+       enzyme["runtime"] = tdiff(&start, &end);
+       printf("%f ", res);
+       enzyme["result"].push_back(res);
+       test_suite["tools"].push_back(enzyme);
+
+       printf("\n");
+    }
+    }
+    {
+
+    struct LSTMInput input = {};
+
+    // Read instance
+    read_lstm_instance("data/" + path, &input.l, &input.c, &input.b,
+                       input.main_params, input.extra_params, input.state,
+                       input.sequence);
+
+    std::vector<double> state = std::vector<double>(input.state.size());
+
+    int Jcols = 8 * input.l * input.b + 3 * input.b;
+    struct LSTMOutput result = {0, std::vector<double>(Jcols)};
+
+    {
+       struct timeval start, end;
+       gettimeofday(&start, NULL);
+       double res =calculate_unsafe_primal(input);
+       gettimeofday(&end, NULL);
+       printf("Enzyme (unsafe Rust) primal %0.6f\n", tdiff(&start, &end));
+       json enzyme;
+       enzyme["name"] = "Enzyme (unsafe Rust) primal";
+       enzyme["runtime"] = tdiff(&start, &end);
+       printf("%f ", res);
+       enzyme["result"].push_back(res);
+       test_suite["tools"].push_back(enzyme);
+
+       printf("\n");
+    }
+    }
+    {
+
+    struct LSTMInput input = {};
+
+    // Read instance
+    read_lstm_instance("data/" + path, &input.l, &input.c, &input.b,
+                       input.main_params, input.extra_params, input.state,
+                       input.sequence);
+
+    std::vector<double> state = std::vector<double>(input.state.size());
+
+    int Jcols = 8 * input.l * input.b + 3 * input.b;
+    struct LSTMOutput result = {0, std::vector<double>(Jcols)};
+
+    {
+       struct timeval start, end;
+       gettimeofday(&start, NULL);
+       double res = calculate_safe_primal(input);
+       gettimeofday(&end, NULL);
+       printf("Enzyme (safe Rust) primal %0.6f\n", tdiff(&start, &end));
+       json enzyme;
+       enzyme["name"] = "Enzyme (safe Rust) primal";
+       enzyme["runtime"] = tdiff(&start, &end);
+       printf("%f ", res);
+       enzyme["result"].push_back(res);
+       test_suite["tools"].push_back(enzyme);
+
+       printf("\n");
+    }
+    }
+
     test_suite["llvm-version"] = __clang_version__;
     test_suite["mode"] = "ReverseMode";
     test_suite["batch-size"] = 1;
diff --git a/enzyme/benchmarks/ReverseMode/ba/Cargo.lock b/enzyme/benchmarks/ReverseMode/ba/Cargo.lock
new file mode 100644
index 0000000..74e2768
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ba/Cargo.lock
@@ -0,0 +1,16 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "bars"
+version = "0.1.0"
+dependencies = [
+ "libm",
+]
+
+[[package]]
+name = "libm"
+version = "0.2.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
diff --git a/enzyme/benchmarks/ReverseMode/ba/Cargo.toml b/enzyme/benchmarks/ReverseMode/ba/Cargo.toml
new file mode 100644
index 0000000..4bc9c21
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ba/Cargo.toml
@@ -0,0 +1,23 @@
+[package]
+name = "bars"
+version = "0.1.0"
+edition = "2021"
+
+[lib]
+crate-type = ["cdylib"]
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+# Release: fat LTO, one codegen unit; panics abort (the profile key is `panic`, not `unwind`).
+[profile.release]
+lto = "fat"
+opt-level = 3
+codegen-units = 1
+panic = "abort"
+strip = true
+#overflow-checks = false
+
+[profile.dev]
+lto = "fat"
+
+[dependencies]
+libm = { version = "0.2.8", optional = true }
diff --git a/enzyme/benchmarks/ReverseMode/ba/Makefile.make b/enzyme/benchmarks/ReverseMode/ba/Makefile.make
index b7f013d..50ab0cf 100644
--- a/enzyme/benchmarks/ReverseMode/ba/Makefile.make
+++ b/enzyme/benchmarks/ReverseMode/ba/Makefile.make
@@ -6,6 +6,10 @@
 
 clean:
 	rm -f *.ll *.o results.txt results.json
+	cargo +enzyme clean
+# Static library of the Rust port (crate `bars`), built with the `enzyme` toolchain; linked into ba.o.
+$(dir)/benchmarks/ReverseMode/ba/target/release/libbars.a: src/lib.rs Cargo.toml
+	RUSTFLAGS="-Z autodiff=Enable" cargo +enzyme rustc --release --lib --crate-type=staticlib --features=libm
 
 %-unopt.ll: %.cpp
 	clang++ $(BENCH) $(PTR) $^ -pthread -O2 -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm
@@ -16,8 +20,8 @@
 %-opt.ll: %-raw.ll
 	opt $^ -o $@ -S
 
-ba.o: ba-opt.ll
+ba.o: ba-opt.ll $(dir)/benchmarks/ReverseMode/ba/target/release/libbars.a
 	clang++ $(BENCH) -pthread -O2 $^ -I /usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -O2 -o $@ $(BENCHLINK) -lpthread -lm -L /usr/lib/gcc/x86_64-linux-gnu/11
 
 results.json: ba.o
-	./$^
+	numactl -C 1 ./$^
diff --git a/enzyme/benchmarks/ReverseMode/ba/ba.cpp b/enzyme/benchmarks/ReverseMode/ba/ba.cpp
index b71e05a..602af73 100644
--- a/enzyme/benchmarks/ReverseMode/ba/ba.cpp
+++ b/enzyme/benchmarks/ReverseMode/ba/ba.cpp
@@ -43,17 +43,13 @@
     return res;
 }
 
-
-
-void cross(double const* a, double const* b, double* out)
-{
+void cross_restrict(double const *__restrict a, double const *__restrict b,
+                    double *__restrict out) {
     out[0] = a[1] * b[2] - a[2] * b[1];
     out[1] = a[2] * b[0] - a[0] * b[2];
     out[2] = a[0] * b[1] - a[1] * b[0];
 }
 
-
-
 /* ===================================================================== */
 /*                               MAIN LOGIC                              */
 /* ===================================================================== */
@@ -68,8 +64,9 @@
 //  n = w / theta;
 //  n_x = au_cross_matrix(n);
 //  R = eye(3) + n_x*sin(theta) + n_x*n_x*(1 - cos(theta));
-void rodrigues_rotate_point(double const* __restrict rot, double const* __restrict pt, double *__restrict rotatedPt)
-{
+void rodrigues_rotate_point_restrict(double const *__restrict rot,
+                                     double const *__restrict pt,
+                                     double *__restrict rotatedPt) {
     int i;
     double sqtheta = sqsum(3, rot);
     if (sqtheta != 0)
@@ -87,7 +84,7 @@
             w[i] = rot[i] * theta_inverse;
         }
 
-        cross(w, pt, w_cross_pt);
+        cross_restrict(w, pt, w_cross_pt);
 
         tmp = (w[0] * pt[0] + w[1] * pt[1] + w[2] * pt[2]) *
             (1. - costheta);
@@ -100,7 +97,7 @@
     else
     {
         double rot_cross_pt[3];
-        cross(rot, pt, rot_cross_pt);
+        cross_restrict(rot, pt, rot_cross_pt);
 
         for (i = 0; i < 3; i++)
         {
@@ -109,8 +106,6 @@
     }
 }
 
-
-
 void radial_distort(double const* rad_params, double *proj)
 {
     double rsq, L;
@@ -120,10 +115,8 @@
     proj[1] = proj[1] * L;
 }
 
-
-
-void project(double const* __restrict cam, double const* __restrict X, double* __restrict proj)
-{
+void project_restrict(double const *__restrict cam, double const *__restrict X,
+                      double *__restrict proj) {
     double const* C = &cam[3];
     double Xo[3], Xcam[3];
 
@@ -131,7 +124,7 @@
     Xo[1] = X[1] - C[1];
     Xo[2] = X[2] - C[2];
 
-    rodrigues_rotate_point(&cam[0], Xo, Xcam);
+    rodrigues_rotate_point_restrict(&cam[0], Xo, Xcam);
 
     proj[0] = Xcam[0] / Xcam[2];
     proj[1] = Xcam[1] / Xcam[2];
@@ -142,8 +135,6 @@
     proj[1] = proj[1] * cam[6] + cam[8];
 }
 
-
-
 // cam: 11 camera in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2]
 //            r1, r2, r3 are angle - axis rotation parameters(Rodrigues)
 //            [C1 C2 C3]' is the camera center
@@ -158,30 +149,23 @@
 // distorted = radial_distort(projective2euclidean(Xcam), radial_parameters)
 // proj = distorted * f + principal_point
 // err = sqsum(proj - measurement)
-void compute_reproj_error(
-    double const* __restrict cam,
-    double const* __restrict X,
-    double const* __restrict w,
-    double const* __restrict feat,
-    double * __restrict err
-)
-{
+void compute_reproj_error_restrict(double const *__restrict cam,
+                                   double const *__restrict X,
+                                   double const *__restrict w,
+                                   double const *__restrict feat,
+                                   double *__restrict err) {
     double proj[2];
-    project(cam, X, proj);
+    project_restrict(cam, X, proj);
 
     err[0] = (*w)*(proj[0] - feat[0]);
     err[1] = (*w)*(proj[1] - feat[1]);
 }
 
-
-
-void compute_zach_weight_error(double const* w, double* err)
-{
+void compute_zach_weight_error_restrict(double const *__restrict w,
+                                        double *__restrict err) {
     *err = 1 - (*w)*(*w);
 }
 
-
-
 // n number of cameras
 // m number of points
 // p number of observations
@@ -196,36 +180,23 @@
 // feats: 2*p features (x,y coordinates corresponding to observations)
 // reproj_err: 2*p errors of observations
 // w_err: p weight "error" terms
-void ba_objective(
-    int n,
-    int m,
-    int p,
-    double const* cams,
-    double const* X,
-    double const* w,
-    int const* obs,
-    double const* feats,
-    double* reproj_err,
-    double* w_err
-)
-{
+void ba_objective_restrict(int n, int m, int p, double const *cams,
+                           double const *X, double const *w, int const *obs,
+                           double const *feats, double *reproj_err,
+                           double *w_err) {
     int i;
     for (i = 0; i < p; i++)
     {
         int camIdx = obs[i * 2 + 0];
         int ptIdx = obs[i * 2 + 1];
-        compute_reproj_error(
-            &cams[camIdx * BA_NCAMPARAMS],
-            &X[ptIdx * 3],
-            &w[i],
-            &feats[i * 2],
-            &reproj_err[2 * i]
-        );
+        compute_reproj_error_restrict(&cams[camIdx * BA_NCAMPARAMS],
+                                      &X[ptIdx * 3], &w[i], &feats[i * 2],
+                                      &reproj_err[2 * i]);
     }
 
     for (i = 0; i < p; i++)
     {
-        compute_zach_weight_error(&w[i], &w_err[i]);
+        compute_zach_weight_error_restrict(&w[i], &w_err[i]);
     }
 }
 
@@ -234,32 +205,21 @@
 extern int enzyme_dupnoneed;
 void __enzyme_autodiff(...) noexcept;
 
-void dcompute_reproj_error(
-    double const* cam,
-    double * dcam,
-    double const* X,
-    double * dX,
-    double const* w,
-    double * wb,
-    double const* feat,
-    double *err,
-    double *derr
-)
-{
-    __enzyme_autodiff(compute_reproj_error,
-            enzyme_dup, cam, dcam,
-            enzyme_dup, X, dX,
-            enzyme_dup, w, wb,
-            enzyme_const, feat,
-            enzyme_dupnoneed, err, derr);
+void dcompute_reproj_error_restrict(double const *cam, double *dcam,
+                                    double const *X, double *dX,
+                                    double const *w, double *wb,
+                                    double const *feat, double *err,
+                                    double *derr) {
+    __enzyme_autodiff(compute_reproj_error_restrict, enzyme_dup, cam, dcam,
+                      enzyme_dup, X, dX, enzyme_dup, w, wb, enzyme_const, feat,
+                      enzyme_dupnoneed, err, derr);
 }
 
-void dcompute_zach_weight_error(double const* w, double* dw, double* err, double* derr) {
-    __enzyme_autodiff(compute_zach_weight_error,
-            enzyme_dup, w, dw,
-            enzyme_dupnoneed, err, derr);
+void dcompute_zach_weight_error_restrict(double const *w, double *dw,
+                                         double *err, double *derr) {
+    __enzyme_autodiff(compute_zach_weight_error_restrict, enzyme_dup, w, dw,
+                      enzyme_dupnoneed, err, derr);
 }
-
 }
 
 
@@ -911,3 +871,5 @@
 
       *dw = aw.get_gradient();
 }
+
+#include "ba_mayalias.h"
diff --git a/enzyme/benchmarks/ReverseMode/ba/ba_mayalias.h b/enzyme/benchmarks/ReverseMode/ba/ba_mayalias.h
new file mode 100644
index 0000000..25197b5
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ba/ba_mayalias.h
@@ -0,0 +1,198 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+extern "C" {
+
+/* ===================================================================== */
+/*                                UTILS                                  */
+/* ===================================================================== */
+
+void cross(double const *a, double const *b, double *out) { // out = a x b for 3-vectors
+  out[0] = a[1] * b[2] - a[2] * b[1]; // components are stored one at a time, so the result is
+  out[1] = a[2] * b[0] - a[0] * b[2]; // only correct when out does not overlap a or b (later
+  out[2] = a[0] * b[1] - a[1] * b[0]; // components would otherwise read overwritten inputs)
+}
+
+/* ===================================================================== */
+/*                               MAIN LOGIC                              */
+/* ===================================================================== */
+
+void compute_zach_weight_error(double const *w, double *err) {
+  *err = 1.0 - *w * *w; /* weight residual 1 - w^2 for one observation */
+}
+
+// Rotates a 3-D point by an angle-axis (Rodrigues) rotation vector.
+// rot: 3 rotation parameters (axis scaled by the rotation angle)
+// pt: 3 point to be rotated
+// rotatedPt: 3 output rotated point
+// This is an efficient evaluation (part of the Ceres implementation); the
+// easy-to-understand calculation in matlab is:
+//  theta = sqrt(sum(w. ^ 2));
+//  n = w / theta;
+//  n_x = au_cross_matrix(n);
+//  R = eye(3) + n_x*sin(theta) + n_x*n_x*(1 - cos(theta));
+void rodrigues_rotate_point(double const *rot, double const *pt,
+                            double *rotatedPt) {
+    int i;
+    double sqtheta = sqsum(3, rot); // squared angle theta^2 (sum of squares of rot; sqsum is defined elsewhere)
+    if (sqtheta != 0)
+    {
+        double theta, costheta, sintheta, theta_inverse;
+        double w[3], w_cross_pt[3], tmp;
+        // Nonzero angle: exact Rodrigues formula.
+        theta = sqrt(sqtheta);
+        costheta = cos(theta);
+        sintheta = sin(theta);
+        theta_inverse = 1.0 / theta;
+        // w = rot / theta, the unit rotation axis.
+        for (i = 0; i < 3; i++)
+        {
+            w[i] = rot[i] * theta_inverse;
+        }
+        // w_cross_pt = w x pt
+        cross(w, pt, w_cross_pt);
+        // tmp = (w . pt) * (1 - cos(theta))
+        tmp = (w[0] * pt[0] + w[1] * pt[1] + w[2] * pt[2]) *
+            (1. - costheta);
+        // rotatedPt = pt*cos(theta) + (w x pt)*sin(theta) + w*tmp
+        for (i = 0; i < 3; i++)
+        {
+            rotatedPt[i] = pt[i] * costheta + w_cross_pt[i] * sintheta + w[i] * tmp;
+        }
+    }
+    else // zero angle: first-order form; presumably kept so derivatives w.r.t. rot stay defined
+    {
+        double rot_cross_pt[3];
+        cross(rot, pt, rot_cross_pt);
+        // rotatedPt = pt + rot x pt
+        for (i = 0; i < 3; i++)
+        {
+            rotatedPt[i] = pt[i] + rot_cross_pt[i];
+        }
+    }
+}
+
+void project(double const *cam, double const *X, double *proj) { // projects point X through camera cam into proj[0..1]
+    double const* C = &cam[3];
+    double Xo[3], Xcam[3];
+    // Xo = X - C, with the camera center C = cam[3..5]
+    Xo[0] = X[0] - C[0];
+    Xo[1] = X[1] - C[1];
+    Xo[2] = X[2] - C[2];
+    // Rotate into the camera frame with the angle-axis parameters cam[0..2].
+    rodrigues_rotate_point(&cam[0], Xo, Xcam);
+    // Perspective division.
+    proj[0] = Xcam[0] / Xcam[2];
+    proj[1] = Xcam[1] / Xcam[2];
+    // Radial distortion with k1, k2 = cam[9], cam[10].
+    radial_distort(&cam[9], proj);
+    // Focal length cam[6] and principal point (cam[7], cam[8]).
+    proj[0] = proj[0] * cam[6] + cam[7];
+    proj[1] = proj[1] * cam[6] + cam[8];
+}
+
+// cam: 11 camera parameters in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2]
+//            r1, r2, r3 are angle-axis rotation parameters (Rodrigues)
+//            [C1 C2 C3]' is the camera center
+//            f is the focal length in pixels
+//            [u0 v0]' is the principal point
+//            k1, k2 are radial distortion parameters
+// X: 3 point; w: 1 weight applied to this observation's residual
+// feat: 2 feature (x,y coordinates); err: 2 output residual
+// projection function:
+// Xcam = R * (X - C)
+// distorted = radial_distort(projective2euclidean(Xcam), radial_parameters)
+// proj = distorted * f + principal_point
+// err = w * (proj - feat)   (a weighted 2-vector residual, not a sum of squares)
+//
+void compute_reproj_error(double const *cam, double const *X, double const *w,
+                          double const *feat, double *err) {
+    double proj[2];
+    project(cam, X, proj);
+    // Scale each image-space difference by the observation weight.
+    err[0] = (*w)*(proj[0] - feat[0]);
+    err[1] = (*w)*(proj[1] - feat[1]);
+}
+
+
+
+
+// n number of cameras
+// m number of points
+// p number of observations
+// cams: 11*n cameras in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2]
+//            r1, r2, r3 are angle - axis rotation parameters(Rodrigues)
+//            [C1 C2 C3]' is the camera center
+//            f is the focal length in pixels
+//            [u0 v0]' is the principal point
+//            k1, k2 are radial distortion parameters
+// X: 3*m points
+// obs: 2*p observations (pairs cameraIdx, pointIdx)
+// feats: 2*p features (x,y coordinates corresponding to observations)
+// reproj_err: 2*p errors of observations
+// w_err: p weight "error" terms
+void ba_objective(
+    int n, // n and m are not read here; they only describe the sizes of cams and X
+    int m,
+    int p,
+    double const* cams,
+    double const* X,
+    double const* w, // p weights, one per observation
+    int const* obs,
+    double const* feats,
+    double* reproj_err,
+    double* w_err
+)
+{
+    int i;
+    for (i = 0; i < p; i++)
+    {
+        int camIdx = obs[i * 2 + 0];
+        int ptIdx = obs[i * 2 + 1];
+        compute_reproj_error( // writes reproj_err[2*i] and reproj_err[2*i + 1]
+            &cams[camIdx * BA_NCAMPARAMS],
+            &X[ptIdx * 3],
+            &w[i],
+            &feats[i * 2],
+            &reproj_err[2 * i]
+        );
+    }
+    // Weight residuals: w_err[i] = 1 - w[i]^2 for every observation.
+    for (i = 0; i < p; i++)
+    {
+        compute_zach_weight_error(&w[i], &w_err[i]);
+    }
+}
+
+extern int enzyme_const;
+extern int enzyme_dup;
+extern int enzyme_dupnoneed;
+void __enzyme_autodiff(...) noexcept;
+
+void dcompute_reproj_error( // Enzyme reverse-mode gradient of the may-alias compute_reproj_error
+    double const* cam,
+    double * dcam, // receives the gradient w.r.t. cam
+    double const* X,
+    double * dX, // receives the gradient w.r.t. X
+    double const* w,
+    double * wb, // receives the gradient w.r.t. w
+    double const* feat, // constant: no gradient is produced
+    double *err,
+    double *derr // shadow of err: the incoming adjoint of the residual
+)
+{
+    __enzyme_autodiff(compute_reproj_error,
+            enzyme_dup, cam, dcam,
+            enzyme_dup, X, dX,
+            enzyme_dup, w, wb,
+            enzyme_const, feat,
+            enzyme_dupnoneed, err, derr); // primal err is not needed (may be left unwritten)
+}
+
+void dcompute_zach_weight_error(double const* w, double* dw, double* err, double* derr) { // gradient of 1 - w^2: dw gets d/dw, derr is err's adjoint
+    __enzyme_autodiff(compute_zach_weight_error,
+            enzyme_dup, w, dw,
+            enzyme_dupnoneed, err, derr); // primal err is not needed (may be left unwritten)
+}
+
+}
diff --git a/enzyme/benchmarks/ReverseMode/ba/src/lib.rs b/enzyme/benchmarks/ReverseMode/ba/src/lib.rs
new file mode 100644
index 0000000..7efd43f
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ba/src/lib.rs
@@ -0,0 +1,25 @@
+#![feature(autodiff)]
+#![allow(non_snake_case)]
+
+use std::autodiff::autodiff;
+pub mod safe;
+pub mod r#unsafe;
+
+const BA_NCAMPARAMS: usize = 11; // parameters per camera: [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2]
+
+#[no_mangle]
+pub extern "C" fn rust_dcompute_zach_weight_error(
+    w: *const f64, dw: *mut f64,
+    err: *mut f64, derr: *mut f64,
+) {
+    // C entry point: forwards to the Enzyme-generated reverse-mode derivative.
+    // The gradient w.r.t. *w goes to *dw; *derr supplies the adjoint of *err.
+    dcompute_zach_weight_error(w, dw, err, derr);
+}
+
+#[autodiff(dcompute_zach_weight_error, Reverse, Duplicated, Duplicated)]
+pub fn compute_zach_weight_error(w: *const f64, err: *mut f64) {
+    // NOTE(review): safe fn dereferencing raw pointers; callers must pass valid f64 pointers.
+    unsafe { *err = 1. - *w * *w };
+}
+
diff --git a/enzyme/benchmarks/ReverseMode/ba/src/main.rs b/enzyme/benchmarks/ReverseMode/ba/src/main.rs
new file mode 100644
index 0000000..13f221b
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ba/src/main.rs
@@ -0,0 +1,26 @@
+use bars::{dcompute_zach_weight_error, safe::dcompute_reproj_error};
+fn main() { // smoke test: run both Enzyme-generated gradients once on all-zero inputs
+    let cam = [0.0; 11];
+    let mut dcam = [0.0; 11];
+    let x = [0.0; 3];
+    let mut dx = [0.0; 3];
+    let w = [0.0; 1];
+    let mut dw = [0.0; 1];
+    let feat = [0.0; 2];
+    let mut err = [0.0; 2];
+    let mut derr = [0.0; 2];
+    dcompute_reproj_error(
+        &cam as *const [f64;11],
+        &mut dcam as *mut [f64;11],
+        &x as *const [f64;3],
+        &mut dx as *mut [f64;3],
+        &w as *const [f64;1],
+        &mut dw as *mut [f64;1],
+        &feat as *const [f64;2],
+        &mut err as *mut [f64;2],
+        &mut derr as *mut [f64;2],
+    );
+
+    // Reuses the first element of w/dw/err/derr for the weight-error gradient.
+    dcompute_zach_weight_error(&w as *const f64, &mut dw as *mut f64, &mut err as *mut f64, &mut derr as *mut f64);
+}
diff --git a/enzyme/benchmarks/ReverseMode/ba/src/safe.rs b/enzyme/benchmarks/ReverseMode/ba/src/safe.rs
new file mode 100644
index 0000000..3530c79
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ba/src/safe.rs
@@ -0,0 +1,204 @@
+use crate::BA_NCAMPARAMS;
+use crate::compute_zach_weight_error;
+use std::autodiff::autodiff;
+
+fn sqsum(x: &[f64]) -> f64 {
+    x.iter().copied().map(|v| v * v).sum::<f64>() // squared Euclidean norm of x
+}
+
+/// Cross product `a × b` of two 3-vectors.
+/// Returns a fresh array; neither input is modified.
+#[inline]
+fn cross(a: &[f64; 3], b: &[f64; 3]) -> [f64; 3] {
+    let [ax, ay, az] = *a;
+    let [bx, by, bz] = *b;
+    [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
+}
+
+fn radial_distort(rad_params: &[f64], proj: &mut [f64]) {
+    let rsq = sqsum(proj); // squared norm of proj
+    let scale = 1. + rad_params[0] * rsq + rad_params[1] * rsq * rsq; // 1 + k1 r^2 + k2 r^4
+    proj[0] *= scale;
+    proj[1] *= scale;
+}
+
+/// Rotates `pt` by the angle-axis vector `rot` (Rodrigues' formula) into `rotated_pt`.
+fn rodrigues_rotate_point(rot: &[f64; 3], pt: &[f64; 3], rotated_pt: &mut [f64; 3]) {
+    let sqtheta = sqsum(rot);
+    if sqtheta == 0. {
+        // Zero angle: first-order form pt + rot x pt.
+        let rc = cross(rot, pt);
+        for k in 0..3 {
+            rotated_pt[k] = pt[k] + rc[k];
+        }
+        return;
+    }
+    let theta = sqtheta.sqrt();
+    let (sintheta, costheta) = (theta.sin(), theta.cos());
+    let inv = 1. / theta;
+    // Unit rotation axis.
+    let mut axis = [0.; 3];
+    for k in 0..3 { axis[k] = rot[k] * inv; }
+    let axis_x_pt = cross(&axis, pt);
+    let along = (axis[0] * pt[0] + axis[1] * pt[1] + axis[2] * pt[2]) * (1. - costheta);
+    for k in 0..3 {
+        rotated_pt[k] = pt[k] * costheta + axis_x_pt[k] * sintheta + axis[k] * along;
+    }
+}
+
+/// Projects world point `X` through camera `cam` into image coordinates `proj`.
+fn project(cam: &[f64; 11], X: &[f64; 3], proj: &mut [f64; 2]) {
+    // Camera-relative point Xo = X - C, with C = cam[3..6].
+    let center = &cam[3..6];
+    let Xo = [X[0] - center[0], X[1] - center[1], X[2] - center[2]];
+    let mut Xcam = [0.; 3];
+    rodrigues_rotate_point(cam.first_chunk::<3>().unwrap(), &Xo, &mut Xcam);
+
+    // Perspective division.
+    proj[0] = Xcam[0] / Xcam[2];
+    proj[1] = Xcam[1] / Xcam[2];
+
+    // Radial distortion with k1, k2 = cam[9], cam[10].
+    radial_distort(&cam[9..], proj);
+
+    // Focal length cam[6] and principal point (cam[7], cam[8]).
+    proj[0] = proj[0] * cam[6] + cam[7];
+    proj[1] = proj[1] * cam[6] + cam[8];
+}
+
+#[no_mangle]
+pub extern "C" fn rust_dcompute_reproj_error(
+    cam: *const [f64; 11], dcam: *mut [f64; 11],
+    x: *const [f64; 3], dx: *mut [f64; 3],
+    w: *const [f64; 1], wb: *mut [f64; 1],
+    feat: *const [f64; 2],
+    err: *mut [f64; 2], derr: *mut [f64; 2],
+) {
+    // C entry point for the safe-Rust variant. Each differentiated input comes
+    // with its shadow: gradients w.r.t. cam, x and w go to dcam, dx and wb;
+    // feat is constant; derr supplies the adjoint of err.
+    // NOTE(review): safe fn forwarding raw pointers; the C caller must pass valid buffers.
+    unsafe { dcompute_reproj_error(cam, dcam, x, dx, w, wb, feat, err, derr) };
+}
+
+/// Weighted reprojection residual: `err = w * (project(cam, x) - feat)`.
+///
+/// Enzyme activities, in argument order: cam, x and w are `Duplicated`, feat is
+/// `Const`, and err is `DuplicatedOnly` (its primal value is not needed).
+#[autodiff(
+    dcompute_reproj_error, Reverse,
+    Duplicated, Duplicated, Duplicated, Const, DuplicatedOnly
+)]
+pub fn compute_reproj_error(
+    cam: *const [f64; 11],
+    x: *const [f64; 3],
+    w: *const [f64; 1],
+    feat: *const [f64; 2],
+    err: *mut [f64; 2],
+) {
+    // NOTE(review): safe fn dereferencing raw pointers; callers must pass valid,
+    // aligned, live pointers (this is not enforced by the type system).
+    let (cam, x, feat) = unsafe { (&*cam, &*x, &*feat) };
+    let weight = unsafe { (*w)[0] };
+    let out = unsafe { &mut *err };
+    let mut proj = [0.; 2];
+    project(cam, x, &mut proj);
+    for k in 0..2 {
+        out[k] = weight * (proj[k] - feat[k]);
+    }
+}
+
+// n number of cameras
+// m number of points
+// p number of observations
+// cams: 11*n cameras in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2]
+//            r1, r2, r3 are angle - axis rotation parameters(Rodrigues)
+//            [C1 C2 C3]' is the camera center
+//            f is the focal length in pixels
+//            [u0 v0]' is the principal point
+//            k1, k2 are radial distortion parameters
+// X: 3*m points
+// w: p weights (one per observation)
+// obs: 2*p observations (pairs cameraIdx, pointIdx)
+// feats: 2*p features (x,y coordinates corresponding to observations)
+// reproj_err: 2*p errors of observations
+// w_err: p weight "error" terms
+//
+// Panics if a slice length disagrees with n/m/p, or if an observation holds a
+// negative or out-of-range camera/point index: those indices come from input
+// data, so they are bounds-checked instead of being read with get_unchecked.
+fn rust_ba_objective(
+    n: usize,
+    m: usize,
+    p: usize,
+    cams: &[f64],
+    x: &[f64],
+    w: &[f64],
+    obs: &[i32],
+    feats: &[f64],
+    reproj_err: &mut [f64],
+    w_err: &mut [f64],
+) {
+    assert_eq!(cams.len(), n * 11);
+    assert_eq!(x.len(), m * 3);
+    assert_eq!(w.len(), p);
+    assert_eq!(obs.len(), p * 2);
+    assert_eq!(feats.len(), p * 2);
+    assert_eq!(reproj_err.len(), p * 2);
+    assert_eq!(w_err.len(), p);
+
+    for i in 0..p {
+        let cam_idx = usize::try_from(obs[i * 2]).expect("negative camera index");
+        let pt_idx = usize::try_from(obs[i * 2 + 1]).expect("negative point index");
+        let cam: &[f64; 11] = cams
+            .get(cam_idx * BA_NCAMPARAMS..)
+            .and_then(|c| c.first_chunk::<11>())
+            .expect("camera index out of range");
+        let pt: &[f64; 3] = x
+            .get(pt_idx * 3..)
+            .and_then(|c| c.first_chunk::<3>())
+            .expect("point index out of range");
+        // SAFETY: i < p, and the asserts above guarantee w.len() == p and
+        // feats.len() == reproj_err.len() == 2 * p, so these ranges are in bounds
+        // and have exactly the array lengths requested.
+        let (wi, feat, err): (&[f64; 1], &[f64; 2], &mut [f64; 2]) = unsafe {
+            (
+                w.get_unchecked(i..i + 1).try_into().unwrap_unchecked(),
+                feats.get_unchecked(i * 2..i * 2 + 2).try_into().unwrap_unchecked(),
+                reproj_err
+                    .get_unchecked_mut(i * 2..i * 2 + 2)
+                    .try_into()
+                    .unwrap_unchecked(),
+            )
+        };
+        compute_reproj_error(cam, pt, wi, feat, err);
+    }
+
+    // One weight residual per observation; w and w_err both have length p.
+    for (wi, wi_err) in w.iter().zip(w_err.iter_mut()) {
+        compute_zach_weight_error(wi, wi_err);
+    }
+}
+
+#[no_mangle]
+extern "C" fn rust2_ba_objective(
+    n: i32, // counts arrive as C `int`s; reading them as usize would see
+    m: i32, // undefined upper register bits under the C calling convention
+    p: i32,
+    cams: *const f64,
+    x: *const f64,
+    w: *const f64,
+    obs: *const i32,
+    feats: *const f64,
+    reproj_err: *mut f64,
+    w_err: *mut f64,
+) {
+    use std::slice::{from_raw_parts as view, from_raw_parts_mut as view_mut};
+    // A negative count is a caller bug: panic (aborting at this FFI boundary).
+    let [n, m, p] = [n, m, p].map(|v| usize::try_from(v).expect("negative count"));
+    // SAFETY: the caller guarantees each pointer is valid for the element count used below.
+    let (cams, x, w) = unsafe { (view(cams, n * 11), view(x, m * 3), view(w, p)) };
+    let (obs, feats) = unsafe { (view(obs, p * 2), view(feats, p * 2)) };
+    let (reproj_err, w_err) = unsafe { (view_mut(reproj_err, p * 2), view_mut(w_err, p)) };
+    rust_ba_objective(n, m, p, cams, x, w, obs, feats, reproj_err, w_err);
+}
diff --git a/enzyme/benchmarks/ReverseMode/ba/src/unsafe.rs b/enzyme/benchmarks/ReverseMode/ba/src/unsafe.rs
new file mode 100644
index 0000000..09f74be
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ba/src/unsafe.rs
@@ -0,0 +1,140 @@
+use crate::BA_NCAMPARAMS;
+use crate::compute_zach_weight_error;
+use std::autodiff::autodiff;
+
+/// Sum of squares of the `n` f64 values starting at `x`.
+/// Safety: `x` must be valid for `n` consecutive reads.
+unsafe fn sqsum(x: *const f64, n: usize) -> f64 {
+    (0..n).fold(0., |acc, k| {
+        let v = unsafe { *x.add(k) };
+        acc + v * v
+    })
+}
+
+#[inline]
+unsafe fn cross(a: *const f64, b: *const f64, out: *mut f64) { // writes a x b into out[0..3]
+    *out.add(0) = *a.add(1) * *b.add(2) - *a.add(2) * *b.add(1); // components are stored one at a time:
+    *out.add(1) = *a.add(2) * *b.add(0) - *a.add(0) * *b.add(2); // out must not overlap a or b, or later
+    *out.add(2) = *a.add(0) * *b.add(1) - *a.add(1) * *b.add(0); // components read overwritten inputs
+}
+
+unsafe fn radial_distort(rad_params: *const f64, proj: *mut f64) {
+    // Scale proj[0..2] by 1 + k1*r^2 + k2*r^4, with (k1, k2) = rad_params[0..2].
+    let rsq = sqsum(proj, 2);
+    let scale = 1. + *rad_params * rsq + *rad_params.add(1) * rsq * rsq;
+    for k in 0..2 { *proj.add(k) *= scale; }
+}
+
+/// Rotates the 3-vector at `pt` by the angle-axis vector at `rot` (Rodrigues'
+/// formula), writing the result to `rotated_pt`. All pointers address 3 f64s.
+unsafe fn rodrigues_rotate_point(rot: *const f64, pt: *const f64, rotated_pt: *mut f64) {
+    let sqtheta = sqsum(rot, 3);
+    if sqtheta == 0. {
+        // Zero angle: first-order form pt + rot x pt.
+        let mut rc = [0.; 3];
+        cross(rot, pt, rc.as_mut_ptr());
+        for k in 0..3 {
+            *rotated_pt.add(k) = *pt.add(k) + rc[k];
+        }
+        return;
+    }
+    let theta = sqtheta.sqrt();
+    let (sintheta, costheta) = (theta.sin(), theta.cos());
+    let inv = 1. / theta;
+    let mut axis = [0.; 3];
+    for k in 0..3 { axis[k] = *rot.add(k) * inv; }
+    let mut axis_x_pt = [0.; 3];
+    cross(axis.as_ptr(), pt, axis_x_pt.as_mut_ptr());
+    let along = (axis[0] * *pt + axis[1] * *pt.add(1) + axis[2] * *pt.add(2)) * (1. - costheta);
+    for k in 0..3 {
+        *rotated_pt.add(k) = *pt.add(k) * costheta + axis_x_pt[k] * sintheta + axis[k] * along;
+    }
+}
+
+/// Projects the world point at `X` through the 11-parameter camera at `cam`,
+/// writing the 2 image coordinates to `proj`.
+unsafe fn project(cam: *const f64, X: *const f64, proj: *mut f64) {
+    // Camera-frame point: rotate (X - C) by the angle-axis cam[0..3], C = cam[3..6].
+    let center = cam.add(3);
+    let mut Xo = [0.; 3];
+    for k in 0..3 { Xo[k] = *X.add(k) - *center.add(k); }
+    let mut Xcam = [0.; 3];
+    rodrigues_rotate_point(cam, Xo.as_ptr(), Xcam.as_mut_ptr());
+    // Perspective division, then radial distortion with k1, k2 = cam[9], cam[10].
+    *proj = Xcam[0] / Xcam[2];
+    *proj.add(1) = Xcam[1] / Xcam[2];
+    radial_distort(cam.add(9), proj);
+
+    // Focal length cam[6] and principal point (cam[7], cam[8]).
+    *proj = *proj * *cam.add(6) + *cam.add(7);
+    *proj.add(1) = *proj.add(1) * *cam.add(6) + *cam.add(8);
+}
+
+/// C entry point for the raw-pointer variant: forwards each (primal, shadow)
+/// pair to the Enzyme-generated reverse-mode derivative `dcompute_reproj_error`.
+///
+/// # Safety
+/// Every pointer must be valid for the element counts read/written by `compute_reproj_error`.
+#[no_mangle]
+pub unsafe extern "C" fn rust_unsafe_dcompute_reproj_error(
+    cam: *const f64, dcam: *mut f64,
+    x: *const f64, dx: *mut f64,
+    w: *const f64, wb: *mut f64,
+    feat: *const f64, err: *mut f64, derr: *mut f64,
+) {
+    unsafe { dcompute_reproj_error(cam, dcam, x, dx, w, wb, feat, err, derr) };
+}
+
+
+/// Weighted reprojection residual `err[0..2] = *w * (project(cam, x) - feat[0..2])`.
+///
+/// # Safety
+/// `cam` (11), `x` (3), `w` (1) and `feat` (2) must be readable and `err` (2) writable.
+#[autodiff(
+    dcompute_reproj_error, Reverse,
+    Duplicated, Duplicated, Duplicated, Const, DuplicatedOnly
+)]
+pub unsafe fn compute_reproj_error(
+    cam: *const f64,
+    x: *const f64,
+    w: *const f64,
+    feat: *const f64,
+    err: *mut f64,
+) {
+    let mut proj = [0.; 2];
+    project(cam, x, proj.as_mut_ptr());
+    for k in 0..2 {
+        *err.add(k) = *w * (proj[k] - *feat.add(k));
+    }
+}
+
+#[no_mangle]
+unsafe extern "C" fn rust2_unsafe_ba_objective(
+    n: i32, // counts are C `int`s (see the C declaration); receiving them as usize
+    m: i32, // would read undefined upper register bits under the C calling convention
+    p: i32,
+    cams: *const f64,
+    x: *const f64,
+    w: *const f64,
+    obs: *const i32,
+    feats: *const f64,
+    reproj_err: *mut f64,
+    w_err: *mut f64,
+) {
+    let _ = (n, m); // sizes of cams/x are the caller's contract; only p drives the loops
+    let p = p.max(0) as usize; // like the C loop `i < p`, a negative count does nothing
+    for i in 0..p {
+        let cam_idx = *obs.add(i * 2 + 0) as usize;
+        let pt_idx = *obs.add(i * 2 + 1) as usize;
+        let cam: *const f64 = cams.add(cam_idx * BA_NCAMPARAMS);
+        let x: *const f64 = x.add(pt_idx * 3);
+        let w: *const f64 = w.add(i);
+        let feat: *const f64 = feats.add(i * 2);
+        let reproj_err: *mut f64 = reproj_err.add(i * 2);
+        compute_reproj_error(cam, x, w, feat, reproj_err);
+    }
+
+    for i in 0..p {
+        compute_zach_weight_error(w.add(i), w_err.add(i));
+    }
+}
diff --git a/enzyme/benchmarks/ReverseMode/fft/Cargo.lock b/enzyme/benchmarks/ReverseMode/fft/Cargo.lock
new file mode 100644
index 0000000..44847ec
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/fft/Cargo.lock
@@ -0,0 +1,7 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "fft"
+version = "0.1.0"
diff --git a/enzyme/benchmarks/ReverseMode/fft/Cargo.toml b/enzyme/benchmarks/ReverseMode/fft/Cargo.toml
new file mode 100644
index 0000000..cf8862d
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/fft/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "fft"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+
+[lib]
+crate-type = ["lib"]
+
+[profile.release]
+lto = "fat"
+opt-level = 3
+codegen-units = 1
+panic = "abort" # Cargo has no `unwind` profile key; `panic` selects the abort strategy
+strip = true
+#overflow-checks = false
+
+[profile.dev]
+lto = "fat"
diff --git a/enzyme/benchmarks/ReverseMode/fft/Makefile.make b/enzyme/benchmarks/ReverseMode/fft/Makefile.make
index 17ea03a..b9385cd 100644
--- a/enzyme/benchmarks/ReverseMode/fft/Makefile.make
+++ b/enzyme/benchmarks/ReverseMode/fft/Makefile.make
@@ -7,6 +7,9 @@
 clean:
 	rm -f *.ll *.o results.txt results.json
 
+$(dir)/benchmarks/ReverseMode/fft/target/release/libfft.a: src/lib.rs Cargo.toml
+	RUSTFLAGS="-Z autodiff=Enable" cargo +enzyme rustc --release --lib --crate-type=staticlib
+
 %-unopt.ll: %.cpp
 	clang++ $(BENCH) $(PTR) $^ -pthread -O2 -fno-use-cxa-atexit -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm
 
@@ -16,7 +19,7 @@
 %-opt.ll: %-raw.ll
 	opt $^ -o $@ -S
 
-fft.o: fft-opt.ll
+fft.o: fft-opt.ll $(dir)/benchmarks/ReverseMode/fft/target/release/libfft.a
 	clang++ $(BENCH) -pthread -O2 $^ -o $@ $(BENCHLINK) -lpthread -lm -L /usr/lib/gcc/x86_64-linux-gnu/11
 	#clang++ $(LOAD) $(BENCH) fft.cpp -I /usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -O2 -o fft.o -lpthread $(BENCHLINK) -lm -L /usr/lib/gcc/x86_64-linux-gnu/11
 
diff --git a/enzyme/benchmarks/ReverseMode/fft/fft.cpp b/enzyme/benchmarks/ReverseMode/fft/fft.cpp
index cf9459b..799b7b1 100644
--- a/enzyme/benchmarks/ReverseMode/fft/fft.cpp
+++ b/enzyme/benchmarks/ReverseMode/fft/fft.cpp
@@ -1,237 +1,368 @@
+#include <adept.h>
+#include <adept_source.h>
+#include <inttypes.h>
 #include <math.h>
 #include <stdio.h>
 #include <stdlib.h>
-#include <sys/time.h>
-#include <stdlib.h>
-#include <math.h>
-#include <inttypes.h>
 #include <string.h>
-#include <adept_source.h>
-#include <adept.h>
+#include <sys/time.h>
 using adept::adouble;
 
-template<typename Return, typename... T>
-Return __enzyme_autodiff(T...);
+template <typename Return, typename... T> Return __enzyme_autodiff(T...);
 
 float tdiff(struct timeval *start, struct timeval *end) {
-  return (end->tv_sec-start->tv_sec) + 1e-6*(end->tv_usec-start->tv_usec);
+  return (end->tv_sec - start->tv_sec) + 1e-6 * (end->tv_usec - start->tv_usec); // elapsed wall time in seconds
 }
 
 #include "fft.h"
 
-void foobar(double* data, unsigned len) {
+void foobar(double *data, size_t len) { // fft then ifft on data (callers allocate 2*len doubles)
   fft(data, len);
   ifft(data, len);
 }
 
-void afoobar(aVector& data, unsigned len) {
+void afoobar(aVector &data, size_t len) { // Adept (aVector) version of foobar: fft then ifft on data
   fft(data, len);
   ifft(data, len);
 }
 
 extern "C" {
-  int enzyme_dupnoneed;
+int enzyme_dupnoneed;
 }
 
-static double foobar_and_gradient(unsigned len) {
-    double *inp = new double[2*len];
-    for(int i=0; i<2*len; i++) inp[i] = 2.0;
-    double *dinp = new double[2*len];
-    for(int i=0; i<2*len; i++) dinp[i] = 1.0;
-    __enzyme_autodiff<void>(foobar, enzyme_dupnoneed, inp, dinp, len);
-    double res = dinp[0];
-    delete[] dinp;
-    delete[] inp;
-    return res;
+extern "C" void rust_unsafe_dfoobar(size_t n, double *data, double *ddata);
+extern "C" void rust_unsafe_foobar(size_t n, double *data);
+extern "C" void rust_dfoobar(size_t n, double *data, double *ddata);
+extern "C" void rust_foobar(size_t n, double *data);
+
+static double rust_unsafe_foobar_and_gradient(size_t len) { // gradient via the unsafe-Rust Enzyme build; returns dinp[0]
+  double *inp = new double[2 * len];
+  for (size_t i = 0; i < 2 * len; i++)
+    inp[i] = 2.0;
+  double *dinp = new double[2 * len];
+  for (size_t i = 0; i < 2 * len; i++)
+    dinp[i] = 1.0; // seed adjoint of 1 for every element
+  rust_unsafe_dfoobar(len, inp, dinp); // Rust reverse-mode derivative; results are read back from dinp
+  double res = dinp[0]; // only the first gradient component is reported
+  delete[] dinp;
+  delete[] inp;
+  return res;
 }
 
-static double afoobar_and_gradient(unsigned len) {
-    adept::Stack stack;
-
-    aVector x(2*len);
-    for(int i=0; i<2*len; i++) x(i) = 2.0;
-    stack.new_recording();
-    afoobar(x, len);
-    for(int i=0; i<2*len; i++)
-      x(i).set_gradient(1.0);
-    stack.compute_adjoint();
-
-    double *dinp = new double[2*len];
-    for(int i=0; i<2*len; i++)
-      dinp[i] = x(i).get_gradient();
-    double res = dinp[0];
-    delete[] dinp;
-    return res;
+static double rust_foobar_and_gradient(size_t len) { // gradient via the safe-Rust Enzyme build; returns dinp[0]
+  double *inp = new double[2 * len];
+  for (size_t i = 0; i < 2 * len; i++)
+    inp[i] = 2.0;
+  double *dinp = new double[2 * len];
+  for (size_t i = 0; i < 2 * len; i++)
+    dinp[i] = 1.0; // seed adjoint of 1 for every element
+  rust_dfoobar(len, inp, dinp); // Rust reverse-mode derivative; results are read back from dinp
+  double res = dinp[0]; // only the first gradient component is reported
+  delete[] dinp;
+  delete[] inp;
+  return res;
 }
 
-
-static double tfoobar_and_gradient(unsigned len) {
-    double *inp = new double[2*len];
-    for(int i=0; i<2*len; i++) inp[i] = 2.0;
-    double *dinp = new double[2*len];
-    for(int i=0; i<2*len; i++) dinp[i] = 1.0;
-    foobar_b(inp, dinp, len);
-    double res = dinp[0];
-    delete[] dinp;
-    delete[] inp;
-    return res;
+__attribute__((noinline)) static double foobar_and_gradient(size_t len) { // C++ Enzyme gradient; NOTE(review): noinline presumably keeps it out of the timing caller
+  double *inp = new double[2 * len];
+  for (size_t i = 0; i < 2 * len; i++)
+    inp[i] = 2.0;
+  double *dinp = new double[2 * len];
+  for (size_t i = 0; i < 2 * len; i++)
+    dinp[i] = 1.0; // seed adjoint of 1 for every element
+  __enzyme_autodiff<void>(foobar, enzyme_dupnoneed, inp, dinp, len); // primal result in inp is not needed
+  double res = dinp[0]; // only the first gradient component is reported
+  delete[] dinp;
+  delete[] inp;
+  return res;
 }
 
-static void adept_sincos(double inp, unsigned len) {
-  {
-  struct timeval start, end;
-  gettimeofday(&start, NULL);
-
-  double *x = new double[2*len];
-  for(int i=0; i<2*len; i++) x[i] = 2.0;
-  foobar(x, len);
-  double res = x[0];
-
-  gettimeofday(&end, NULL);
-  printf("Adept real %0.6f res=%f\n", tdiff(&start, &end), res);
-  delete[] x;
-  }
-
-  {
-  struct timeval start, end;
-  gettimeofday(&start, NULL);
-
+static double afoobar_and_gradient(size_t len) { // Adept tape-based gradient; returns the gradient read from x(0)
   adept::Stack stack;
 
-  aVector x(2*len);
-  for(int i=0; i<2*len; i++) x[i] = 2.0;
- // stack.new_recording();
+  aVector x(2 * len);
+  for (size_t i = 0; i < 2 * len; i++)
+    x(i) = 2.0;
+  stack.new_recording(); // start a fresh recording before the taped forward pass
   afoobar(x, len);
-  double res = x(0).value();
+  for (size_t i = 0; i < 2 * len; i++)
+    x(i).set_gradient(1.0); // seed adjoint of 1 for every element
+  stack.compute_adjoint(); // reverse sweep over the recorded tape
 
-  gettimeofday(&end, NULL);
-  printf("Adept forward %0.6f res=%f\n", tdiff(&start, &end), res);
+  double *dinp = new double[2 * len];
+  for (size_t i = 0; i < 2 * len; i++)
+    dinp[i] = x(i).get_gradient();
+  double res = dinp[0];
+  delete[] dinp;
+  return res;
+}
+
+static double tfoobar_and_gradient(size_t len) { // Tapenade gradient via foobar_b; only referenced from commented-out code
+  double *inp = new double[2 * len];
+  for (size_t i = 0; i < 2 * len; i++)
+    inp[i] = 2.0;
+  double *dinp = new double[2 * len];
+  for (size_t i = 0; i < 2 * len; i++)
+    dinp[i] = 1.0; // seed adjoint of 1 for every element
+  foobar_b(inp, dinp, len);
+  double res = dinp[0];
+  delete[] dinp;
+  delete[] inp;
+  return res;
+}
+
+static void adept_sincos(double inp, size_t len) { // times primal, Adept forward, Adept gradient; inp is unused
+  {
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+    // "real": plain double round trip through foobar (no AD types)
+    double *x = new double[2 * len];
+    for (size_t i = 0; i < 2 * len; i++)
+      x[i] = 2.0;
+    foobar(x, len);
+    double res = x[0];
+
+    gettimeofday(&end, NULL);
+    printf("Adept real %0.6f res=%f\n", tdiff(&start, &end), res);
+    delete[] x;
   }
 
   {
-  struct timeval start, end;
-  gettimeofday(&start, NULL);
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
 
-  double res2 = afoobar_and_gradient(len);
+    adept::Stack stack;
 
-  gettimeofday(&end, NULL);
-  printf("Adept combined %0.6f res'=%f\n", tdiff(&start, &end), res2);
+    aVector x(2 * len);
+    for (size_t i = 0; i < 2 * len; i++)
+      x[i] = 2.0;
+    // stack.new_recording();
+    afoobar(x, len);
+    double res = x(0).value();
+    // "forward": the same computation on Adept aVector types
+    gettimeofday(&end, NULL);
+    printf("Adept forward %0.6f res=%f\n", tdiff(&start, &end), res);
+  }
+  // "combined": forward pass plus reverse sweep via afoobar_and_gradient
+  {
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+
+    double res2 = afoobar_and_gradient(len);
+
+    gettimeofday(&end, NULL);
+    printf("Adept combined %0.6f res'=%f\n", tdiff(&start, &end), res2);
   }
 }
 
+static void tapenade_sincos(double inp, size_t len) { // NOTE(review): whole body is commented out, so this is a no-op
 
-static void tapenade_sincos(double inp, unsigned len) {
+  // {
+  // struct timeval start, end;
+  // gettimeofday(&start, NULL);
+
+  // double *x = new double[2*len];
+  // for(size_t i=0; i<2*len; i++) x[i] = 2.0;
+  // foobar(x, len);
+  // double res = x[0];
+
+  // gettimeofday(&end, NULL);
+  // printf("Tapenade real %0.6f res=%f\n", tdiff(&start, &end), res);
+  // delete[] x;
+  // }
+
+  // {
+  // struct timeval start, end;
+  // gettimeofday(&start, NULL);
+
+  // double* x = new double[2*len];
+  // for(size_t i=0; i<2*len; i++) x[i] = 2.0;
+  // foobar(x, len);
+  // double res = x[0];
+
+  // gettimeofday(&end, NULL);
+  // printf("Tapenade forward %0.6f res=%f\n", tdiff(&start, &end), res);
+  // delete[] x;
+  // }
+
+  // {
+  // struct timeval start, end;
+  // gettimeofday(&start, NULL);
+
+  // double res2 = tfoobar_and_gradient(len);
+
+  // gettimeofday(&end, NULL);
+  // printf("Tapenade combined %0.6f res'=%f\n", tdiff(&start, &end), res2);
+  // }
+}
+
+static void enzyme_sincos(double inp, size_t len) { // times primal, "forward" and C++ Enzyme gradient; inp is unused
 
   {
-  struct timeval start, end;
-  gettimeofday(&start, NULL);
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
 
-  double *x = new double[2*len];
-  for(int i=0; i<2*len; i++) x[i] = 2.0;
-  foobar(x, len);
-  double res = x[0];
+    double *x = new double[2 * len];
+    for (size_t i = 0; i < 2 * len; i++)
+      x[i] = 2.0;
+    foobar(x, len);
+    double res = x[0];
 
-  gettimeofday(&end, NULL);
-  printf("Tapenade real %0.6f res=%f\n", tdiff(&start, &end), res);
-  delete[] x;
+    gettimeofday(&end, NULL);
+    printf("Enzyme real %0.6f res=%f\n", tdiff(&start, &end), res);
+    delete[] x;
   }
 
   {
-  struct timeval start, end;
-  gettimeofday(&start, NULL);
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
 
-  double* x = new double[2*len];
-  for(int i=0; i<2*len; i++) x[i] = 2.0;
-  foobar(x, len);
-  double res = x[0];
+    double *x = new double[2 * len]; // NOTE: "forward" runs the same primal foobar as "real"
+    for (size_t i = 0; i < 2 * len; i++)
+      x[i] = 2.0;
+    foobar(x, len);
+    double res = x[0];
 
-  gettimeofday(&end, NULL);
-  printf("Tapenade forward %0.6f res=%f\n", tdiff(&start, &end), res);
-  delete[] x;
+    gettimeofday(&end, NULL);
+    printf("Enzyme forward %0.6f res=%f\n", tdiff(&start, &end), res);
+    delete[] x;
   }
 
   {
-  struct timeval start, end;
-  gettimeofday(&start, NULL);
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
 
-  double res2 = tfoobar_and_gradient(len);
+    double res2 = foobar_and_gradient(len); // primal + reverse pass through __enzyme_autodiff
 
-  gettimeofday(&end, NULL);
-  printf("Tapenade combined %0.6f res'=%f\n", tdiff(&start, &end), res2);
+    gettimeofday(&end, NULL);
+    printf("Enzyme combined %0.6f res'=%f\n", tdiff(&start, &end), res2);
   }
 }
 
-static void enzyme_sincos(double inp, unsigned len) {
+static void enzyme_unsafe_rust_sincos(double inp, size_t len) { // same timings via the unsafe-Rust FFI entry points; inp is unused
 
   {
-  struct timeval start, end;
-  gettimeofday(&start, NULL);
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
 
-  double *x = new double[2*len];
-  for(int i=0; i<2*len; i++) x[i] = 2.0;
-  foobar(x, len);
-  double res = x[0];
+    double *x = new double[2 * len];
+    for (size_t i = 0; i < 2 * len; i++)
+      x[i] = 2.0;
+    rust_unsafe_foobar(len, x);
+    double res = x[0];
 
-  gettimeofday(&end, NULL);
-  printf("Enzyme real %0.6f res=%f\n", tdiff(&start, &end), res);
-  delete[] x;
+    gettimeofday(&end, NULL);
+    printf("Enzyme (unsafe Rust) real %0.6f res=%f\n", tdiff(&start, &end),
+           res);
+    delete[] x;
   }
 
   {
-  struct timeval start, end;
-  gettimeofday(&start, NULL);
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
 
-  double *x = new double[2*len];
-  for(int i=0; i<2*len; i++) x[i] = 2.0;
-  foobar(x, len);
-  double res = x[0];
+    double *x = new double[2 * len]; // NOTE: "forward" runs the same primal as "real"
+    for (size_t i = 0; i < 2 * len; i++)
+      x[i] = 2.0;
+    rust_unsafe_foobar(len, x);
+    double res = x[0];
 
-  gettimeofday(&end, NULL);
-  printf("Enzyme forward %0.6f res=%f\n", tdiff(&start, &end), res);
-  delete[] x;
+    gettimeofday(&end, NULL);
+    printf("Enzyme (unsafe Rust) forward %0.6f res=%f\n", tdiff(&start, &end),
+           res);
+    delete[] x;
   }
 
   {
-  struct timeval start, end;
-  gettimeofday(&start, NULL);
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
 
-  double res2 = foobar_and_gradient(len);
+    double res2 = rust_unsafe_foobar_and_gradient(len);
 
-  gettimeofday(&end, NULL);
-  printf("Enzyme combined %0.6f res'=%f\n", tdiff(&start, &end), res2);
+    gettimeofday(&end, NULL);
+    printf("Enzyme (unsafe Rust) combined %0.6f res'=%f\n", tdiff(&start, &end),
+           res2);
   }
 }
 
+static void enzyme_rust_sincos(double inp, size_t len) { // inp is unused
+  // Primal, "forward" and gradient timings through the safe-Rust FFI entry points.
+  {
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+
+    double *x = new double[2 * len];
+    for (size_t i = 0; i < 2 * len; i++)
+      x[i] = 2.0;
+    rust_foobar(len, x);
+    double res = x[0];
+
+    gettimeofday(&end, NULL);
+    printf("Enzyme (Rust) real %0.6f res=%f\n", tdiff(&start, &end), res);
+    delete[] x;
+  }
+
+  {
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+
+    double *x = new double[2 * len]; // NOTE: "forward" runs the same primal as "real"
+    for (size_t i = 0; i < 2 * len; i++)
+      x[i] = 2.0;
+    rust_foobar(len, x);
+    double res = x[0];
+
+    gettimeofday(&end, NULL);
+    printf("Enzyme (Rust) forward %0.6f res=%f\n", tdiff(&start, &end), res);
+    delete[] x;
+  }
+
+  {
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+
+    double res2 = rust_foobar_and_gradient(len);
+
+    gettimeofday(&end, NULL);
+    printf("Enzyme (Rust) combined %0.6f res'=%f\n", tdiff(&start, &end), res2);
+  }
+}
 
 /* Function to check if x is power of 2*/
-bool isPowerOfTwo (int x)
-{
-    /* First x in the below expression is for the case when x is 0 */
-    return x && (!(x&(x-1)));
+bool isPowerOfTwo(size_t x) { // x & (x - 1) clears the lowest set bit, leaving 0 only for powers of two (and 0)
+  /* The leading `x &&` rejects 0, for which x & (x - 1) is also 0. */
+  return x && (!(x & (x - 1)));
 }
 
-unsigned max(unsigned A, unsigned B){
-  if (A>B) return A;
+size_t max(size_t A, size_t B) { // larger of A and B (returns B on ties)
+  if (A > B)
+    return A;
   return B;
 }
 
-int main(int argc, char** argv) {
+int main(int argc, char **argv) { // usage: prog n, where n is a power of two
 
   if (argc < 2) {
     printf("usage %s n [must be power of 2]\n", argv[0]);
     return 1;
   }
+  size_t N = atol(argv[1]);
   if (!isPowerOfTwo(N)) {
     printf("usage %s n [must be power of 2]\n", argv[0]);
     return 1;
   }
   double inp = -2.1;
 
-  for(unsigned iters=max(1, N>>5); iters <= N; iters*=2) {
-    printf("iters=%d\n", iters);
+  size_t iters = max(1, N >> 0); // NOTE(review): N >> 0 is just N; iters never changes below
+  for (size_t i = 0; i < 5; i++) { // so the same problem size is benchmarked 5 times
+    printf("iters=%zu\n", iters);
+#if CPP // CPP selects the C++ (Adept/Tapenade/Enzyme) drivers; otherwise the Rust ones run
     adept_sincos(inp, iters);
     tapenade_sincos(inp, iters);
     enzyme_sincos(inp, iters);
+#else
+    enzyme_rust_sincos(inp, iters);
+    enzyme_unsafe_rust_sincos(inp, iters);
+#endif
   }
 }
diff --git a/enzyme/benchmarks/ReverseMode/fft/fft.h b/enzyme/benchmarks/ReverseMode/fft/fft.h
index 809196b..fad3c7d 100644
--- a/enzyme/benchmarks/ReverseMode/fft/fft.h
+++ b/enzyme/benchmarks/ReverseMode/fft/fft.h
@@ -1,71 +1,75 @@
 #ifndef _fft_h_
 #define _fft_h_
 
-#include <adept_source.h>
 #include <adept.h>
 #include <adept_arrays.h>
+#include <adept_source.h>
 using adept::adouble;
 
 using adept::aVector;
 
-
 /*
   A classy FFT and Inverse FFT C++ class library
 
   Author: Tim Molteno, tim@physics.otago.ac.nz
 
-  Based on the article "A Simple and Efficient FFT Implementation in C++" by Volodymyr Myrnyy
-  with just a simple Inverse FFT modification.
+  Based on the article "A Simple and Efficient FFT Implementation in C++" by
+  Volodymyr Myrnyy with just a simple Inverse FFT modification.
 
   Licensed under the GPL v3.
 */
 
-
 #include <cmath>
 
-inline void swap(double* a, double* b) {
-  double temp=*a;
+inline void swap(double *a, double *b) {
+  double temp = *a;
   *a = *b;
   *b = temp;
 }
 
-static void recursiveApply(double* data, int iSign, unsigned N) {
-  if (N == 1) return;
-  recursiveApply(data, iSign, N/2);
-  recursiveApply(data+N, iSign, N/2);
+static void recursiveApply(double *__restrict data, size_t N, int iSign) {
+  if (N == 1)
+    return;
+  recursiveApply(data, N / 2, iSign);
+  recursiveApply(data + N, N / 2, iSign);
 
-  double wtemp = iSign*sin(M_PI/N);
-  double wpi = -iSign*sin(2*M_PI/N);
-  double wpr = -2.0*wtemp*wtemp;
+  double wtemp = iSign * sin(M_PI / N);
+  double wpi = -iSign * sin(2 * (M_PI / N));
+  double wpr = -2.0 * wtemp * wtemp;
   double wr = 1.0;
   double wi = 0.0;
 
-  for (unsigned i=0; i<N; i+=2) {
-    int iN = i+N;
+  for (size_t i = 0; i < N; i += 2) {
+    size_t iN = i + N;
+    double *__restrict ay = &data[i + 1];
+    double *__restrict ax = &data[i];
+    double *__restrict by = &data[iN + 1];
+    double *__restrict bx = &data[iN];
 
-    double tempr = data[iN]*wr - data[iN+1]*wi;
-    double tempi = data[iN]*wi + data[iN+1]*wr;
+    double tempr = *bx * wr - *by * wi;
+    double tempi = *bx * wi + *by * wr;
 
-    data[iN] = data[i]-tempr;
-    data[iN+1] = data[i+1]-tempi;
-    data[i] += tempr;
-    data[i+1] += tempi;
+    *bx = *ax - tempr;
+    *by = *ay - tempi;
+    *ax += tempr;
+    *ay += tempi;
 
     wtemp = wr;
-    wr += wr*wpr - wi*wpi;
-    wi += wi*wpr + wtemp*wpi;
+    wr = wr * (wpr + 1.) - wi * wpi;
+    wi = wi * (wpr + 1.) + wtemp * wpi;
   }
 }
 
-static void scramble(double* data, unsigned N) {
-  int j=1;
-  for (int i=1; i<2*N; i+=2) {
-    if (j>i) {
-      swap(&data[j-1], &data[i-1]);
+static void scramble(double *data, size_t N) {
+  size_t j = 1;
+  for (size_t ii = 0; ii < N; ii++) {
+    size_t i = 2 * ii + 1;
+    if (j > i) {
+      swap(&data[j - 1], &data[i - 1]);
       swap(&data[j], &data[i]);
     }
-    int m = N;
-    while (m>=2 && j>m) {
+    size_t m = N;
+    while (m >= 2 && j > m) {
       j -= m;
       m >>= 1;
     }
@@ -73,69 +77,71 @@
   }
 }
 
-static void rescale(double* data, unsigned N) {
-  double scale = ((double)1)/N;
-  for (unsigned i=0; i<2*N; i++) {
+static void rescale(double *data, size_t N) {
+  double scale = ((double)1) / N;
+  for (size_t i = 0; i < 2 * N; i++) {
     data[i] *= scale;
   }
 }
 
-static void fft(double* data, unsigned N) {
+static void fft(double *data, size_t N) {
   scramble(data, N);
-  recursiveApply(data,1, N);
+  recursiveApply(data, N, 1);
 }
 
-static void ifft(double* data, unsigned N) {
+static void ifft(double *data, size_t N) {
   scramble(data, N);
-  recursiveApply(data,-1, N);
+  recursiveApply(data, N, -1);
   rescale(data, N);
 }
 
-
-
-inline void swapad(adept::ActiveReference<double> a, adept::ActiveReference<double> b) {
-  adouble temp=a;
+inline void swapad(adept::ActiveReference<double> a,
+                   adept::ActiveReference<double> b) {
+  adouble temp = a;
   a = b;
   b = temp;
 }
 
-static void recursiveApply(aVector data, int iSign, unsigned N) {
-  if (N == 1) return;
-  recursiveApply(data, iSign, N/2);
-  recursiveApply(data(adept::range(N,adept::end)), iSign, N/2);
+static void recursiveApply(aVector data, size_t N, int iSign) {
+  if (N == 1)
+    return;
+  recursiveApply(data, N / 2, iSign);
+  recursiveApply(data(adept::range(N, adept::end)), N / 2, iSign);
 
-  adouble wtemp = iSign*std::sin(M_PI/N);
-  adouble wpi = -iSign*std::sin(2*M_PI/N);
-  adouble wpr = -2.0*wtemp*wtemp;
+  adouble wtemp = iSign * std::sin(M_PI / N);
+  adouble wpi = -iSign * std::sin(2 * (M_PI / N));
+  adouble wpr = -2.0 * wtemp * wtemp;
   adouble wr = 1.0;
   adouble wi = 0.0;
 
-  for (unsigned i=0; i<N; i+=2) {
-    int iN = i+N;
+  for (size_t ii = 0; ii < N / 2; ii++) {
+    size_t i = 2 * ii;
+    size_t iN = i + N;
 
-    adouble tempr = data(iN)*wr - data(iN+1)*wi;
-    adouble tempi = data(iN)*wi + data(iN+1)*wr;
+    adouble tempr = data(iN) * wr - data(iN + 1) * wi;
+    adouble tempi = data(iN) * wi + data(iN + 1) * wr;
 
-    data(iN) = data(i)-tempr;
-    data(iN+1) = data(i+1)-tempi;
+    data(iN) = data(i) - tempr;
+    data(iN + 1) = data(i + 1) - tempi;
     data(i) += tempr;
-    data(i+1) += tempi;
+    data(i + 1) += tempi;
 
     wtemp = wr;
-    wr += wr*wpr - wi*wpi;
-    wi += wi*wpr + wtemp*wpi;
+    wr = wr * (wpr + 1.) - wi * wpi;
+    wi = wi * (wpr + 1.) + wtemp * wpi;
   }
 }
 
-static void scramble(aVector data, unsigned N) {
-  int j=1;
-  for (int i=1; i<2*N; i+=2) {
-    if (j>i) {
-      swapad(data(j-1), data(i-1));
+static void scramble(aVector data, size_t N) {
+  size_t j = 1;
+  for (size_t ii = 0; ii < N; ii++) {
+    size_t i = 2 * ii + 1;
+    if (j > i) {
+      swapad(data(j - 1), data(i - 1));
       swapad(data(j), data(i));
     }
-    int m = N;
-    while (m>=2 && j>m) {
+    size_t m = N;
+    while (m >= 2 && j > m) {
       j -= m;
       m >>= 1;
     }
@@ -143,21 +149,21 @@
   }
 }
 
-static void rescale(aVector data, unsigned N) {
-  adouble scale = ((double)1)/N;
-  for (unsigned i=0; i<2*N; i++) {
+static void rescale(aVector data, size_t N) {
+  adouble scale = ((double)1) / N;
+  for (size_t i = 0; i < 2 * N; i++) {
     data[i] *= scale;
   }
 }
 
-static void fft(aVector data, unsigned N) {
+static void fft(aVector data, size_t N) {
   scramble(data, N);
-  recursiveApply(data,1, N);
+  recursiveApply(data, N, 1);
 }
 
-static void ifft(aVector data, unsigned N) {
+static void ifft(aVector data, size_t N) {
   scramble(data, N);
-  recursiveApply(data,-1, N);
+  recursiveApply(data, N, -1);
   rescale(data, N);
 }
 
@@ -165,260 +171,308 @@
 extern "C" {
 
 /*        Generated by TAPENADE     (INRIA, Ecuador team)
-    Tapenade 3.15 (master) - 15 Apr 2020 11:54
+    Tapenade 3.16 (bugfix_servletAD) -  4 Jan 2024 17:44
 */
 #include <adBuffer.h>
+#include <adStack.h>
+#include <math.h>
 
 /*
-  Differentiation of recursiveApply in reverse (adjoint) mode (with options context):
-   gradient     of useful results: *data
-   with respect to varying inputs: *data
-   Plus diff mem management of: data:in
-*/
-static void recursiveApply_b(double *data, double *datab, int iSign, unsigned 
-        int N) {
-    int arg1;
-    double *arg10;
-    double *arg10b;
-    int arg2;
-    if (N != 1) {
-        arg1 = N/2;
-        arg10b = datab + N;
-        arg10 = data + N;
-        arg2 = N/2;
-        double wtemp = iSign*sin(3.1415926536/N);
-        double wpi = -iSign*sin(2*3.1415926536/N);
-        double wpr = -2.0*wtemp*wtemp;
-        double wr = 1.0;
-        double wi = 0.0;
-        for (int i = 0; i <= N-1; i += 2) {
-            int iN = i + N;
-            double tempr = data[iN]*wr - data[iN+1]*wi;
-            double tempi = data[iN]*wi + data[iN+1]*wr;
-            double tmp;
-            double tmp0;
-            wtemp = wr;
-            pushReal8(wr);
-            wr = wr + (wr*wpr - wi*wpi);
-            pushReal8(wi);
-            wi = wi + (wi*wpr + wtemp*wpi);
-            pushInteger4(iN);
-        }
-        pushPointer8(arg10b);
-        pushInteger4(arg2);
-        pushInteger4(arg1);
-        popInteger4(&arg1);
-        popInteger4(&arg2);
-        popPointer8((void **)&arg10b);
-        for (int i = N-(N-1)%2-1; i >= 0; i -= 2) {
-            int iN;
-            double tempr;
-            double temprb = 0.0;
-            double tempi;
-            double tempib = 0.0;
-            double tmpb;
-            double tmpb0;
-            popInteger4(&iN);
-            tmpb0 = datab[iN + 1];
-            popReal8(&wi);
-            popReal8(&wr);
-            tempib = datab[i + 1] - tmpb0;
-            temprb = datab[i];
-            datab[iN + 1] = 0.0;
-            datab[i + 1] = datab[i + 1] + tmpb0;
-            tmpb = datab[iN];
-            datab[iN] = 0.0;
-            datab[i] = datab[i] + tmpb;
-            temprb = temprb - tmpb;
-            datab[iN + 1] = datab[iN + 1] + wr*tempib - wi*temprb;
-            datab[iN] = datab[iN] + wi*tempib + wr*temprb;
-        }
-        recursiveApply_b(arg10, arg10b, iSign, arg2);
-        recursiveApply_b(data, datab, iSign, arg1);
-    }
-}
-
-static void recursiveApply_nodiff(double *data, int iSign, unsigned int N) {
-    int arg1;
-    double *arg10;
-    int arg2;
-    if (N == 1)
-        return;
-    else {
-        arg1 = N/2;
-        recursiveApply_nodiff(data, iSign, arg1);
-        arg10 = data + N;
-        arg2 = N/2;
-        recursiveApply_nodiff(arg10, iSign, arg2);
-        double wtemp = iSign*sin(3.1415926536/N);
-        double wpi = -iSign*sin(2*3.1415926536/N);
-        double wpr = -2.0*wtemp*wtemp;
-        double wr = 1.0;
-        double wi = 0.0;
-        for (int i = 0; i <= N-1; i += 2) {
-            int iN = i + N;
-            double tempr = data[iN]*wr - data[iN+1]*wi;
-            double tempi = data[iN]*wi + data[iN+1]*wr;
-            data[iN] = data[i] - tempr;
-            data[iN + 1] = data[i + 1] - tempi;
-            data[i] += tempr;
-            data[i + 1] += tempi;
-            wtemp = wr;
-            wr += wr*wpr - wi*wpi;
-            wi += wi*wpr + wtemp*wpi;
-        }
-    }
-}
-
-/*
-  Differentiation of swap in reverse (adjoint) mode (with options context):
+  Differentiation of swap in reverse (adjoint) mode:
    gradient     of useful results: *a *b
    with respect to varying inputs: *a *b
    Plus diff mem management of: a:in b:in
 */
-static void swap_b(double *a, double *ab, double *b, double *bb) {
-    double temp = *a;
-    double tempb = 0.0;
-    tempb = *bb;
-    *bb = *ab;
-    *ab = tempb;
+inline void swap_b(double *a, double *ab, double *b, double *bb) {
+  double temp = *a;
+  double tempb = 0.0;
+  *a = *b;
+  *b = temp;
+  tempb = *bb;
+  *bb = *ab;
+  *ab = tempb;
 }
 
-static void swap_nodiff(double *a, double *b) {
-    double temp = *a;
-    *a = *b;
-    *b = temp;
+inline void swap_c(double *a, double *b) {
+  double temp = *a;
+  *a = *b;
+  *b = temp;
 }
 
-/*
-  Differentiation of scramble in reverse (adjoint) mode (with options context):
-   gradient     of useful results: *data
-   with respect to varying inputs: *data
-   Plus diff mem management of: data:in
-*/
-static void scramble_b(double *data, double *datab, unsigned int N) {
-    int j = 1;
-    int branch;
-    for (int i = 1; i <= 2*N-1; i += 2) {
-        int adCount;
-        if (j > i)
-            pushControl1b(0);
-        else
-            pushControl1b(1);
-        int m = N;
-        adCount = 0;
-        while(m >= 2 && j > m) {
-            pushInteger4(j);
-            j = j - m;
-            m = m >> 1;
-            adCount = adCount + 1;
-        }
-        pushInteger4(adCount);
-        pushInteger4(j);
-        j = j + m;
+static void recursiveApply_c(double *data, int iSign, size_t N) {
+  size_t arg1;
+  double *arg10;
+  size_t arg2;
+  if (N == 1)
+    return;
+  else {
+    arg1 = N / 2;
+    recursiveApply_c(data, iSign, arg1);
+    arg10 = data + N;
+    arg2 = N / 2;
+    recursiveApply_c(arg10, iSign, arg2);
+    double wtemp = iSign * sin(M_PI / N); // same twiddle angle as recursiveApply
+    double wpi = -iSign * sin(2 * (M_PI / N));
+    double wpr = -2.0 * wtemp * wtemp;
+    double wr = 1.0;
+    double wi = 0.0;
+    for (size_t ii = 0; ii < N / 2; ii++) {
+      size_t i = 2 * ii;
+      size_t iN = i + N;
+      double tempr = data[iN] * wr - data[iN + 1] * wi;
+      double tempi = data[iN] * wi + data[iN + 1] * wr;
+      data[iN] = data[i] - tempr;
+      data[iN + 1] = data[i + 1] - tempi;
+      data[i] += tempr;
+      data[i + 1] += tempi;
+      wtemp = wr;
+      wr += wr * wpr - wi * wpi;
+      wi += wi * wpr + wtemp * wpi;
     }
-    for (int i = 2*N-(2*N-2)%2-1; i >= 1; i -= 2) {
-        int m;
-        int adCount;
-        int i0;
-        popInteger4(&j);
-        popInteger4(&adCount);
-        for (i0 = 1; i0 < adCount+1; ++i0)
-            popInteger4(&j);
-        popControl1b(&branch);
-        if (branch == 0) {
-            swap_b(&(data[j]), &(datab[j]), &(data[i]), &(datab[i]));
-            swap_b(&(data[j - 1]), &(datab[j - 1]), &(data[i - 1]), &(datab[i 
-                   - 1]));
-        }
+  }
+}
+
+/*
+  Differentiation of recursiveApply in reverse (adjoint) mode:
+   gradient     of useful results: *data
+   with respect to varying inputs: *data
+   Plus diff mem management of: data:in
+*/
+static void recursiveApply_b(double *data, double *datab, int iSign, size_t N) {
+  size_t arg1;
+  double *arg10;
+  double *arg10b;
+  size_t arg2;
+  int branch;
+  if (N != 1) {
+    arg1 = N / 2;
+    pushReal8(*data);
+    recursiveApply_c(data, iSign, arg1);
+    arg10b = datab + N;
+    arg10 = data + N;
+    arg2 = N / 2;
+    if (arg10) {
+      pushReal8(*arg10);
+      pushControl1b(1);
+    } else
+      pushControl1b(0);
+    recursiveApply_c(arg10, iSign, arg2);
+    double wtemp = iSign * sin(M_PI / N); // same twiddle angle as recursiveApply
+    double wpi = -iSign * sin(2 * (M_PI / N));
+    double wpr = -2.0 * wtemp * wtemp;
+    double wr = 1.0;
+    double wi = 0.0;
+    for (size_t ii = 0; ii < N / 2; ii++) {
+      size_t i = 2 * ii;
+      int iN = i + N;
+      double tempr = data[iN] * wr - data[iN + 1] * wi;
+      double tempi = data[iN] * wi + data[iN + 1] * wr;
+      double temprb;
+      double tempib;
+      double tmp;
+      double tmp0;
+      tmp = data[i] - tempr;
+      data[iN] = tmp;
+      tmp0 = data[i + 1] - tempi;
+      data[iN + 1] = tmp0;
+      data[i] = data[i] + tempr;
+      data[i + 1] = data[i + 1] + tempi;
+      wtemp = wr;
+      pushReal8(wr);
+      wr = wr + (wr * wpr - wi * wpi);
+      pushReal8(wi);
+      wi = wi + (wi * wpr + wtemp * wpi);
+      pushInteger4(iN);
     }
-}
-
-static void scramble_nodiff(double *data, unsigned int N) {
-    int j = 1;
-    for (int i = 1; i <= 2*N-1; i += 2) {
-        if (j > i) {
-            swap_nodiff(&(data[j - 1]), &(data[i - 1]));
-            swap_nodiff(&(data[j]), &(data[i]));
-        }
-        int m = N;
-        while(m >= 2 && j > m) {
-            j -= m;
-            m >>= 1;
-        }
-        j += m;
+    for (long i = (long)(N - (N - 1) % 2 - 1); i >= 0; i -= 2) { // signed: size_t i >= 0 never ends
+      int iN;
+      double tempr;
+      double temprb = 0.0;
+      double tempi;
+      double tempib = 0.0;
+      double tmpb;
+      double tmpb0;
+      popInteger4(&iN);
+      tmpb0 = datab[iN + 1];
+      popReal8(&wi);
+      popReal8(&wr);
+      tempib = datab[i + 1] - tmpb0;
+      temprb = datab[i];
+      datab[iN + 1] = 0.0;
+      datab[i + 1] = datab[i + 1] + tmpb0;
+      tmpb = datab[iN];
+      datab[iN] = 0.0;
+      datab[i] = datab[i] + tmpb;
+      temprb = temprb - tmpb;
+      datab[iN + 1] = datab[iN + 1] + wr * tempib - wi * temprb;
+      datab[iN] = datab[iN] + wi * tempib + wr * temprb;
     }
+    popControl1b(&branch);
+    if (branch == 1)
+      popReal8(arg10);
+    recursiveApply_b(arg10, arg10b, iSign, arg2);
+    popReal8(data);
+    recursiveApply_b(data, datab, iSign, arg1);
+  }
 }
 
 /*
-  Differentiation of rescale in reverse (adjoint) mode (with options context):
+  Differentiation of scramble in reverse (adjoint) mode:
    gradient     of useful results: *data
    with respect to varying inputs: *data
    Plus diff mem management of: data:in
 */
-static void rescale_b(double *data, double *datab, unsigned int N) {
-    double scale = (double)1/N;
-    pushReal8(scale);
-    popReal8(&scale);
-    for (int i = 2*N-1; i > -1; --i)
-        datab[i] = scale*datab[i];
+static void scramble_b(double *data, double *datab, size_t N) {
+  int j = 1;
+  int branch;
+  for (size_t ii = 0; ii < N; ii++) {
+    size_t i = 2 * ii + 1;
+    int adCount;
+    if (j > i) {
+      pushReal8(data[i - 1]);
+      pushReal8(data[j - 1]);
+      swap_c(&(data[j - 1]), &(data[i - 1]));
+      pushReal8(data[i]);
+      pushReal8(data[j]);
+      swap_c(&(data[j]), &(data[i]));
+      pushControl1b(0);
+    } else
+      pushControl1b(1);
+    size_t m = N;
+    adCount = 0;
+    while (m >= 2 && j > m) {
+      pushInteger4(j);
+      j = j - m;
+      m = m >> 1;
+      adCount = adCount + 1;
+    }
+    pushInteger4(adCount);
+    pushInteger4(j);
+    j = j + m;
+  }
+  for (long i = (long)(2 * N - (2 * N - 2) % 2 - 1); i >= 1; i -= 2) { // signed: size_t wraps past 1
+    size_t m;
+    int adCount;
+    size_t i0;
+    popInteger4(&j);
+    popInteger4(&adCount);
+    for (i0 = 1; i0 < adCount + 1; ++i0)
+      popInteger4(&j);
+    popControl1b(&branch);
+    if (branch == 0) {
+      popReal8(&(data[j]));
+      popReal8(&(data[i]));
+      swap_b(&(data[j]), &(datab[j]), &(data[i]), &(datab[i]));
+      popReal8(&(data[j - 1]));
+      popReal8(&(data[i - 1]));
+      swap_b(&(data[j - 1]), &(datab[j - 1]), &(data[i - 1]), &(datab[i - 1]));
+    }
+  }
 }
 
-static void rescale_nodiff(double *data, unsigned int N) {
-    double scale = (double)1/N;
-    for (int i = 0; i < 2*N; ++i)
-        data[i] *= scale;
+static void scramble_c(double *data, size_t N) {
+  size_t j = 1;
+  for (size_t ii = 0; ii < N; ii++) {
+    size_t i = 2 * ii + 1;
+    if (j > i) {
+      swap_c(&(data[j - 1]), &(data[i - 1]));
+      swap_c(&(data[j]), &(data[i]));
+    }
+    size_t m = N;
+    while (m >= 2 && j > m) {
+      j -= m;
+      m >>= 1;
+    }
+    j += m;
+  }
 }
 
 /*
-  Differentiation of fft in reverse (adjoint) mode (with options context):
+  Differentiation of rescale in reverse (adjoint) mode:
    gradient     of useful results: *data
    with respect to varying inputs: *data
    Plus diff mem management of: data:in
 */
-static void fft_b(double *data, double *datab, unsigned int N) {
-    recursiveApply_b(data, datab, 1, N);
-    scramble_b(data, datab, N);
+static void rescale_b(double *data, double *datab, size_t N) {
+  double scale = (double)1 / N;
+  for (size_t i = 0; i < 2 * N; ++i)
+    data[i] = data[i] * scale;
+  for (long i = (long)(2 * N) - 1; i > -1; --i) // signed: size_t i > -1 is always false
+    datab[i] = scale * datab[i];
 }
 
-static void fft_nodiff(double *data, unsigned int N) {
-    scramble_nodiff(data, N);
-    recursiveApply_nodiff(data, 1, N);
+static void rescale_c(double *data, size_t N) {
+  double scale = (double)1 / N;
+  for (size_t i = 0; i < 2 * N; ++i)
+    data[i] *= scale;
 }
 
 /*
-  Differentiation of ifft in reverse (adjoint) mode (with options context):
+  Differentiation of fiveft in reverse (adjoint) mode:
    gradient     of useful results: *data
    with respect to varying inputs: *data
    Plus diff mem management of: data:in
 */
-static void ifft_b(double *data, double *datab, unsigned int N) {
-    rescale_b(data, datab, N);
-    recursiveApply_b(data, datab, -1, N);
-    scramble_b(data, datab, N);
+void fiveft_b(double *data, double *datab, size_t N) {
+  pushReal8(*data);
+  scramble_c(data, N);
+  pushReal8(*data);
+  recursiveApply_c(data, 1, N);
+  popReal8(data);
+  recursiveApply_b(data, datab, 1, N);
+  popReal8(data);
+  scramble_b(data, datab, N);
 }
 
-static void ifft_nodiff(double *data, unsigned int N) {
-    scramble_nodiff(data, N);
-    recursiveApply_nodiff(data, -1, N);
-    rescale_nodiff(data, N);
+void fiveft_c(double *data, size_t N) {
+  scramble_c(data, N);
+  recursiveApply_c(data, 1, N);
 }
 
 /*
-  Differentiation of foobar in reverse (adjoint) mode (with options context):
+  Differentiation of ifiveft in reverse (adjoint) mode:
    gradient     of useful results: *data
    with respect to varying inputs: *data
-   RW status of diff variables: *data:in-out
    Plus diff mem management of: data:in
 */
-void foobar_b(double *data, double *datab, unsigned int len) {
-    double chksum = 0.0;
-    int i;
-    ifft_b(data, datab, len);
-    fft_b(data, datab, len);
+void ifiveft_b(double *data, double *datab, size_t N) {
+  pushReal8(*data);
+  scramble_c(data, N);
+  pushReal8(*data);
+  recursiveApply_c(data, -1, N);
+  pushReal8(*data);
+  rescale_c(data, N);
+  popReal8(data);
+  rescale_b(data, datab, N);
+  popReal8(data);
+  recursiveApply_b(data, datab, -1, N);
+  popReal8(data);
+  scramble_b(data, datab, N);
 }
 
+void ifiveft_c(double *data, size_t N) {
+  scramble_c(data, N);
+  recursiveApply_c(data, -1, N);
+  rescale_c(data, N);
 }
 
+/*
+  Differentiation of foobar in reverse (adjoint) mode:
+   gradient     of useful results: *data
+   with respect to varying inputs: *data
+   RW status of diff variables: data:(loc) *data:in-out
+   Plus diff mem management of: data:in
+*/
+void foobar_b(double *data, double *datab, size_t len) {
+  pushReal8(*data);
+  fiveft_c(data, len);
+  pushReal8(*data);
+  ifiveft_c(data, len);
+  popReal8(data);
+  ifiveft_b(data, datab, len);
+  popReal8(data);
+  fiveft_b(data, datab, len);
+}
+}
 
 #endif /* _fft_h_ */
diff --git a/enzyme/benchmarks/ReverseMode/fft/src/lib.rs b/enzyme/benchmarks/ReverseMode/fft/src/lib.rs
new file mode 100644
index 0000000..3b49cb6
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/fft/src/lib.rs
@@ -0,0 +1,7 @@
+//! Rust ports of the FFT benchmark in `fft.h`: `safe` operates on slices and
+//! `unsf` on raw pointers; both export `extern "C"` entry points.
+#![feature(autodiff)]
+#![feature(slice_as_chunks)]
+
+pub mod safe;
+pub mod unsf;
diff --git a/enzyme/benchmarks/ReverseMode/fft/src/main.rs b/enzyme/benchmarks/ReverseMode/fft/src/main.rs
new file mode 100644
index 0000000..5f76ad9
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/fft/src/main.rs
@@ -0,0 +1,23 @@
+//! Debug driver: runs the raw-pointer FFT round trip and its reverse-mode
+//! derivative on a fixed 16-value input and prints both buffers.
+use fft::unsf;
+
+fn main() {
+    let len = 16;
+    // `len` complex values stored as interleaved (re, im) pairs: the first
+    // `len` entries are 2.0 and the remaining `len` are 1.0.
+    let mut data = vec![1.0; 2 * len];
+    data[..len].fill(2.0);
+    // Shadow buffer for `data`, passed to the reverse-mode derivative.
+    let mut data_d = vec![1.0; 2 * len];
+
+    // SAFETY: both vectors hold exactly `2 * len` initialised `f64`s, do not
+    // overlap, and outlive both calls; `len` is a power of two.
+    unsafe {
+        unsf::unsafe_dfoobar(len, data.as_mut_ptr(), data_d.as_mut_ptr());
+        unsf::unsafe_foobar(len, data.as_mut_ptr());
+    }
+
+    dbg!(&data_d);
+    dbg!(&data);
+}
diff --git a/enzyme/benchmarks/ReverseMode/fft/src/safe.rs b/enzyme/benchmarks/ReverseMode/fft/src/safe.rs
new file mode 100644
index 0000000..cbca5ab
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/fft/src/safe.rs
@@ -0,0 +1,136 @@
+//! Safe-slice Rust port of the FFT benchmark in `fft.h`, differentiated with
+//! `std::autodiff`.
+//!
+//! Complex values are stored as interleaved `(re, im)` pairs, so `n` complex
+//! values occupy `2 * n` `f64`s.
+
+use std::autodiff::autodiff;
+use std::f64::consts::PI;
+use std::slice;
+
+/// Applies the bit-reversal permutation (the `scramble` step of `fft.h`) in place.
+///
+/// Elements `2k` and `2k + 1` (one complex value) always move together.
+fn bitreversal_perm<T>(data: &mut [T]) {
+    let len = data.len() / 2;
+    let mut j = 1;
+
+    for i in (1..data.len()).step_by(2) {
+        if j > i {
+            data.swap(j - 1, i - 1);
+            data.swap(j, i);
+        }
+
+        let mut m = len;
+        while m >= 2 && j > m {
+            j -= m;
+            m >>= 1;
+        }
+
+        j += m;
+    }
+}
+
+/// Recursive radix-2 butterfly pass over `data`, in place.
+///
+/// `i_sign` is `1` for the forward transform and `-1` for the inverse. Expects
+/// `data` to be in bit-reversed order already (see `bitreversal_perm`).
+fn radix2(data: &mut [f64], i_sign: i32) {
+    let n = data.len() / 2;
+    // `<= 1` rather than `== 1`: a slice with fewer than two elements gives
+    // `n == 0`, which would otherwise recurse forever on empty halves.
+    if n <= 1 {
+        return;
+    }
+
+    let (a, b) = data.split_at_mut(n);
+    radix2(a, i_sign);
+    radix2(b, i_sign);
+
+    // Twiddle-factor recurrence, same as `recursiveApply` in `fft.h`.
+    let wtemp = i_sign as f64 * (PI / n as f64).sin();
+    let wpi = -i_sign as f64 * (2.0 * (PI / n as f64)).sin();
+    let wpr = -2.0 * wtemp * wtemp;
+    let mut wr = 1.0;
+    let mut wi = 0.0;
+
+    // Each `[f64; 2]` chunk is one complex value `(re, im)`.
+    let (achunks, _) = a.as_chunks_mut();
+    let (bchunks, _) = b.as_chunks_mut();
+    for ([ax, ay], [bx, by]) in achunks.iter_mut().zip(bchunks.iter_mut()) {
+        let tempr = *bx * wr - *by * wi;
+        let tempi = *bx * wi + *by * wr;
+
+        *bx = *ax - tempr;
+        *by = *ay - tempi;
+        *ax += tempr;
+        *ay += tempi;
+
+        let wtemp_new = wr;
+        wr = wr * (wpr + 1.0) - wi * wpi;
+        wi = wi * (wpr + 1.0) + wtemp_new * wpi;
+    }
+}
+
+/// Multiplies every element of `data` by `1 / n`.
+fn rescale(data: &mut [f64], n: usize) {
+    let scale = 1. / n as f64;
+    for elm in data {
+        *elm *= scale;
+    }
+}
+
+/// Forward FFT of the `data.len() / 2` complex values in `data`, in place.
+fn fft(data: &mut [f64]) {
+    bitreversal_perm(data);
+    radix2(data, 1);
+}
+
+/// Inverse FFT, including the `1 / n` normalisation, in place.
+fn ifft(data: &mut [f64]) {
+    bitreversal_perm(data);
+    radix2(data, -1);
+    rescale(data, data.len() / 2);
+}
+
+/// Benchmark kernel: forward FFT followed by inverse FFT, in place.
+///
+/// `#[autodiff]` generates `dfoobar(data, ddata)`, the reverse-mode derivative
+/// in which `ddata` is the shadow of `data`.
+#[autodiff(dfoobar, Reverse, DuplicatedOnly)]
+pub fn foobar(data: &mut [f64]) {
+    fft(data);
+    ifft(data);
+}
+
+/// C entry point for the reverse-mode derivative of [`foobar`].
+///
+/// `data` and `ddata` must each point to `2 * n` initialised `f64`s that do not
+/// overlap and stay valid for the whole call; otherwise the behaviour is
+/// undefined even though this function is not declared `unsafe`.
+#[no_mangle]
+pub extern "C" fn rust_dfoobar(n: usize, data: *mut f64, ddata: *mut f64) {
+    // SAFETY: pointer validity, length and non-overlap are the caller's
+    // contract, documented above.
+    let (data, ddata) = unsafe {
+        (
+            slice::from_raw_parts_mut(data, n * 2),
+            slice::from_raw_parts_mut(ddata, n * 2),
+        )
+    };
+
+    // SAFETY: `data` and `ddata` are valid, disjoint slices of equal length.
+    unsafe { dfoobar(data, ddata) };
+}
+
+/// C entry point for [`foobar`] on the `n` complex values at `data`.
+///
+/// `data` must point to `2 * n` initialised `f64`s that stay valid for the
+/// whole call; otherwise the behaviour is undefined.
+#[no_mangle]
+pub extern "C" fn rust_foobar(n: usize, data: *mut f64) {
+    // SAFETY: pointer validity and length are the caller's contract,
+    // documented above.
+    let data = unsafe { slice::from_raw_parts_mut(data, n * 2) };
+    foobar(data);
+}
diff --git a/enzyme/benchmarks/ReverseMode/fft/src/unsf.rs b/enzyme/benchmarks/ReverseMode/fft/src/unsf.rs
new file mode 100644
index 0000000..29c8ceb
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/fft/src/unsf.rs
@@ -0,0 +1,144 @@
+//! Raw-pointer Rust port of the FFT benchmark in `fft.h`, differentiated with
+//! `std::autodiff`.
+//!
+//! `n` (or `len`) always counts complex values; the buffer behind `data` holds
+//! `2 * n` interleaved `(re, im)` `f64`s.
+
+use std::autodiff::autodiff;
+use std::f64::consts::PI;
+
+/// Applies the bit-reversal permutation (the `scramble` step of `fft.h`) to the
+/// `len` complex values at `data`, in place.
+///
+/// # Safety
+/// `data` must be valid for reads and writes of `2 * len` `f64`s.
+unsafe fn bitreversal_perm(data: *mut f64, len: usize) {
+    let mut j = 1;
+
+    for i in (1..2 * len).step_by(2) {
+        if j > i {
+            std::ptr::swap(data.add(j - 1), data.add(i - 1));
+            std::ptr::swap(data.add(j), data.add(i));
+        }
+
+        let mut m = len;
+        while m >= 2 && j > m {
+            j -= m;
+            m >>= 1;
+        }
+
+        j += m;
+    }
+}
+
+/// Recursive radix-2 butterfly pass over the `n` complex values at `data`.
+///
+/// `i_sign` is `1` for the forward transform and `-1` for the inverse.
+///
+/// # Safety
+/// `data` must be valid for reads and writes of `2 * n` `f64`s, and `n` must be
+/// zero or a power of two (other sizes can index past the end of the buffer).
+unsafe fn radix2(data: *mut f64, n: usize, i_sign: i32) {
+    // `<= 1` rather than `== 1`: `n == 0` would otherwise recurse forever.
+    if n <= 1 {
+        return;
+    }
+    radix2(data, n / 2, i_sign);
+    radix2(data.add(n), n / 2, i_sign);
+
+    let wtemp = i_sign as f64 * (PI / n as f64).sin();
+    let wpi = -i_sign as f64 * (2.0 * (PI / n as f64)).sin();
+    let wpr = -2.0 * wtemp * wtemp;
+    let mut wr = 1.0;
+    let mut wi = 0.0;
+
+    for i in (0..n).step_by(2) {
+        let in_n = i + n;
+        // For n >= 2 the indices i, i + 1, in_n and in_n + 1 are distinct, so
+        // these four `&mut` borrows never alias.
+        let ax = &mut *data.add(i);
+        let ay = &mut *data.add(i + 1);
+        let bx = &mut *data.add(in_n);
+        let by = &mut *data.add(in_n + 1);
+        let tempr = *bx * wr - *by * wi;
+        let tempi = *bx * wi + *by * wr;
+
+        *bx = *ax - tempr;
+        *by = *ay - tempi;
+        *ax += tempr;
+        *ay += tempi;
+
+        let wtemp_new = wr;
+        wr = wr * (wpr + 1.0) - wi * wpi;
+        wi = wi * (wpr + 1.0) + wtemp_new * wpi;
+    }
+}
+
+/// Multiplies the `2 * n` `f64`s at `data` by `1 / n`.
+///
+/// # Safety
+/// `data` must be valid for reads and writes of `2 * n` `f64`s.
+unsafe fn rescale(data: *mut f64, n: usize) {
+    let scale = 1. / n as f64;
+    for i in 0..2 * n {
+        *data.add(i) = *data.add(i) * scale;
+    }
+}
+
+/// Forward FFT of the `n` complex values at `data`, in place.
+///
+/// # Safety
+/// Same requirements as `radix2`.
+unsafe fn fft(data: *mut f64, n: usize) {
+    bitreversal_perm(data, n);
+    radix2(data, n, 1);
+}
+
+/// Inverse FFT, including the `1 / n` normalisation, in place.
+///
+/// # Safety
+/// Same requirements as `radix2`.
+unsafe fn ifft(data: *mut f64, n: usize) {
+    bitreversal_perm(data, n);
+    radix2(data, n, -1);
+    rescale(data, n);
+}
+
+/// Benchmark kernel: forward FFT followed by inverse FFT on `n` complex values.
+///
+/// `#[autodiff]` generates `unsafe_dfoobar(n, data, ddata)`, the reverse-mode
+/// derivative with `n` held constant and `ddata` as the shadow of `data`.
+///
+/// # Safety
+/// `data` must be valid for reads and writes of `2 * n` `f64`s, and `n` must be
+/// zero or a power of two.
+#[autodiff(unsafe_dfoobar, Reverse, Const, DuplicatedOnly)]
+pub unsafe fn unsafe_foobar(n: usize, data: *mut f64) {
+    fft(data, n);
+    ifft(data, n);
+}
+
+/// C entry point for the reverse-mode derivative of [`unsafe_foobar`].
+///
+/// `data` and `ddata` must each point to `2 * n` initialised, non-overlapping
+/// `f64`s and `n` must be zero or a power of two; otherwise the behaviour is
+/// undefined even though this function is not declared `unsafe`.
+#[no_mangle]
+pub extern "C" fn rust_unsafe_dfoobar(n: usize, data: *mut f64, ddata: *mut f64) {
+    // SAFETY: the caller upholds the contract documented above.
+    unsafe {
+        unsafe_dfoobar(n, data, ddata);
+    }
+}
+
+/// C entry point for [`unsafe_foobar`].
+///
+/// `data` must point to `2 * n` initialised `f64`s and `n` must be zero or a
+/// power of two; otherwise the behaviour is undefined.
+#[no_mangle]
+pub extern "C" fn rust_unsafe_foobar(n: usize, data: *mut f64) {
+    // SAFETY: the caller upholds the contract documented above.
+    unsafe {
+        unsafe_foobar(n, data);
+    }
+}
diff --git a/enzyme/benchmarks/ReverseMode/gmm/Cargo.lock b/enzyme/benchmarks/ReverseMode/gmm/Cargo.lock
new file mode 100644
index 0000000..cfdab95
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/gmm/Cargo.lock
@@ -0,0 +1,16 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "gmmrs"
+version = "0.1.0"
+dependencies = [
+ "libm",
+]
+
+[[package]]
+name = "libm"
+version = "0.2.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
diff --git a/enzyme/benchmarks/ReverseMode/gmm/Cargo.toml b/enzyme/benchmarks/ReverseMode/gmm/Cargo.toml
new file mode 100644
index 0000000..1ae0273
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/gmm/Cargo.toml
@@ -0,0 +1,26 @@
+[package]
+name = "gmmrs"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[lib]
+crate-type = ["lib"]
+
+[features]
+libm = ["dep:libm"]
+
+[profile.release]
+lto = "fat"
+opt-level = 3
+codegen-units = 1
+panic = "abort"
+strip = true
+#overflow-checks = false
+
+[profile.dev]
+lto = "fat"
+
+[dependencies]
+libm = { version = "0.2.8", optional = true }
diff --git a/enzyme/benchmarks/ReverseMode/gmm/Makefile.make b/enzyme/benchmarks/ReverseMode/gmm/Makefile.make
index 1e8e711..17e22dd 100644
--- a/enzyme/benchmarks/ReverseMode/gmm/Makefile.make
+++ b/enzyme/benchmarks/ReverseMode/gmm/Makefile.make
@@ -6,6 +6,10 @@
 
 clean:
 	rm -f *.ll *.o results.txt results.json
+	cargo +enzyme clean
+
+$(dir)/benchmarks/ReverseMode/gmm/target/release/libgmmrs.a: src/lib.rs Cargo.toml
+	RUSTFLAGS="-Z autodiff=Enable,LooseTypes" cargo +enzyme rustc --release --lib --crate-type=staticlib --features=libm
 
 %-unopt.ll: %.cpp
 	clang++ $(BENCH) $(PTR) $^ -pthread -O2 -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm
@@ -16,9 +20,9 @@
 %-opt.ll: %-raw.ll
 	opt $^ -o $@ -S
 
-gmm.o: gmm-opt.ll
+gmm.o: gmm-opt.ll $(dir)/benchmarks/ReverseMode/gmm/target/release/libgmmrs.a
 	clang++ -pthread -O2 $^ -o $@ $(BENCHLINK) -lm
 	#clang++ $(LOADCLANG) $(BENCH) gmm.cpp -I /usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -O2 -o gmm.o -lpthread $(BENCHLINK) -lm -L /usr/lib/gcc/x86_64-linux-gnu/11
 
 results.json: gmm.o
-	./$^
+	numactl -C 1 ./$^
diff --git a/enzyme/benchmarks/ReverseMode/gmm/gmm.cpp b/enzyme/benchmarks/ReverseMode/gmm/gmm.cpp
index 8660592..cb7e864 100644
--- a/enzyme/benchmarks/ReverseMode/gmm/gmm.cpp
+++ b/enzyme/benchmarks/ReverseMode/gmm/gmm.cpp
@@ -13,7 +13,7 @@
  *   typedef struct
  *   {
  *       double gamma;
- *       int m;
+ *       size_t m;
  *   } Wishart;
  *
  *   After Tapenade CLI installing use the next command to generate a file:
@@ -39,9 +39,9 @@
 /* ==================================================================== */
 
 // This throws error on n<1
-double arr_max(int n, double const* x)
+double arr_max(size_t n, double const* x)
 {
-    int i;
+    size_t i;
     double m = x[0];
     for (i = 1; i < n; i++)
     {
@@ -57,9 +57,9 @@
 
 
 // sum of component squares
-double sqnorm(int n, double const* x)
+double sqnorm(size_t n, double const* x)
 {
-    int i;
+    size_t i;
     double res = x[0] * x[0];
     for (i = 1; i < n; i++)
     {
@@ -73,13 +73,13 @@
 
 // out = a - b
 void subtract(
-    int d,
+    size_t d,
     double const* x,
     double const* y,
     double* out
 )
 {
-    int id;
+    size_t id;
     for (id = 0; id < d; id++)
     {
         out[id] = x[id] - y[id];
@@ -87,9 +87,9 @@
 }
 
 
-double log_sum_exp(int n, double const* x)
+double log_sum_exp(size_t n, double const* x)
 {
-    int i;
+    size_t i;
     double mx = arr_max(n, x);
     double semx = 0.0;
 
@@ -105,7 +105,7 @@
 __attribute__((const))
 double log_gamma_distrib(double a, double p)
 {
-    int j;
+    int64_t j;
     double out = 0.25 * p * (p - 1) * log(PI);
 
     for (j = 1; j <= p; j++)
@@ -123,17 +123,17 @@
 /* ======================================================================== */
 
 double log_wishart_prior(
-    int p,
-    int k,
+    size_t p,
+    size_t k,
     Wishart wishart,
     double const* sum_qs,
     double const* Qdiags,
     double const* icf
 )
 {
-    int ik;
-    int n = p + wishart.m + 1;
-    int icf_sz = p * (p + 1) / 2;
+    size_t ik;
+    size_t n = p + wishart.m + 1;
+    size_t icf_sz = p * (p + 1) / 2;
 
     double C = n * p * (log(wishart.gamma) - 0.5 * log(2)) - log_gamma_distrib(0.5 * n, p);
 
@@ -150,15 +150,15 @@
 
 
 void preprocess_qs(
-    int d,
-    int k,
+    size_t d,
+    size_t k,
     double const* icf,
     double* sum_qs,
     double* Qdiags
 )
 {
-    int ik, id;
-    int icf_sz = d * (d + 1) / 2;
+    size_t ik, id;
+    size_t icf_sz = d * (d + 1) / 2;
     for (ik = 0; ik < k; ik++)
     {
         sum_qs[ik] = 0.;
@@ -174,14 +174,14 @@
 
 
 void Qtimesx(
-    int d,
+    size_t d,
     double const* Qdiag,
     double const* ltri, // strictly lower triangular part
     double const* x,
     double* out
 )
 {
-    int i, j;
+    size_t i, j;
     for (i = 0; i < d; i++)
     {
         out[i] = Qdiag[i] * x[i];
@@ -189,10 +189,10 @@
 
     //caching lparams as scev doesn't replicate index calculation
     // todo note changing to strengthened form
-    //int Lparamsidx = 0;
+    //size_t Lparamsidx = 0;
     for (i = 0; i < d; i++)
     {
-    	int Lparamsidx = i*(2*d-i-1)/2;
+    	size_t Lparamsidx = i*(2*d-i-1)/2;
         for (j = i + 1; j < d; j++)
         {
             // and this x
@@ -202,24 +202,15 @@
     }
 }
 
-
-
-void gmm_objective(
-    int d,
-    int k,
-    int n,
-    double const* __restrict alphas,
-    double const* __restrict means,
-    double const* __restrict icf,
-    double const* __restrict x,
-    Wishart wishart,
-    double* __restrict err
-)
-{
-    #define int int64_t
-    int ix, ik;
-    const double CONSTANT = -n * d * 0.5 * log(2 * PI);
-    int icf_sz = d * (d + 1) / 2;
+void gmm_objective_restrict(size_t d, size_t k, size_t n,
+                            double const *__restrict alphas,
+                            double const *__restrict means,
+                            double const *__restrict icf,
+                            double const *__restrict x, Wishart wishart,
+                            double *__restrict err) {
+    int64_t ix, ik;
+    const double CONSTANT = -(double)n * d * 0.5 * log(2 * PI);
+    int64_t icf_sz = d * (d + 1) / 2;
 
     double* Qdiags = (double*)malloc(d * k * sizeof(double));
     double* sum_qs = (double*)malloc(k * sizeof(double));
@@ -256,7 +247,6 @@
     free(xcentered);
     free(Qxcentered);
     free(main_term);
-    #undef int
 }
 
 extern int enzyme_const;
@@ -265,23 +255,16 @@
 void __enzyme_autodiff(...) noexcept;
 
 // *      tapenade -b -o gmm_tapenade -head "gmm_objective(err)/(alphas means icf)" gmm.c
-void dgmm_objective(int d, int k, int n, const double *alphas, double *
-        alphasb, const double *means, double *meansb, const double *icf,
-        double *icfb, const double *x, Wishart wishart, double *err, double *
-        errb) {
-    __enzyme_autodiff(
-            gmm_objective,
-            enzyme_const, d,
-            enzyme_const, k,
-            enzyme_const, n,
-            enzyme_dup, alphas, alphasb,
-            enzyme_dup, means, meansb,
-            enzyme_dup, icf, icfb,
-            enzyme_const, x,
-            enzyme_const, wishart,
-            enzyme_dupnoneed, err, errb);
+void dgmm_objective_restrict(size_t d, size_t k, size_t n, const double *alphas,
+                             double *alphasb, const double *means,
+                             double *meansb, const double *icf, double *icfb,
+                             const double *x, Wishart wishart, double *err,
+                             double *errb) {
+    __enzyme_autodiff(gmm_objective_restrict, enzyme_const, d, enzyme_const, k,
+                      enzyme_const, n, enzyme_dup, alphas, alphasb, enzyme_dup,
+                      means, meansb, enzyme_dup, icf, icfb, enzyme_const, x,
+                      enzyme_const, wishart, enzyme_dupnoneed, err, errb);
 }
-
 }
 
 
@@ -300,20 +283,19 @@
                                 UTILS
  ==================================================================== */
 // This throws error on n<1
-void arr_max_b(int n, const double *x, double *xb, double arr_maxb) {
-    int i;
+void arr_max_b(size_t n, const double *x, double *xb, double arr_maxb) {
     double m = x[0];
     double mb = 0.0;
     int branch;
     double arr_max;
-    for (i = 1; i < n; ++i)
+    for (int64_t i = 1; i < n; ++i)
         if (m < x[i]) {
             m = x[i];
             pushControl1b(1);
         } else
             pushControl1b(0);
     mb = arr_maxb;
-    for (i = n-1; i > 0; --i) {
+    for (int64_t i = (int64_t)n-1; i > 0; --i) {
         popControl1b(&branch);
         if (branch != 0) {
             xb[i] = xb[i] + mb;
@@ -327,8 +309,8 @@
                                 UTILS
  ==================================================================== */
 // This throws error on n<1
-double arr_max_nodiff(int n, const double *x) {
-    int i;
+double arr_max_nodiff(size_t n, const double *x) {
+    size_t i;
     double m = x[0];
     for (i = 1; i < n; ++i)
         if (m < x[i])
@@ -343,20 +325,19 @@
    Plus diff mem management of: x:in
 */
 // sum of component squares
-void sqnorm_b(int n, const double *x, double *xb, double sqnormb) {
-    int i;
+void sqnorm_b(size_t n, const double *x, double *xb, double sqnormb) {
     double res = x[0]*x[0];
     double resb = 0.0;
     double sqnorm;
     resb = sqnormb;
-    for (i = n-1; i > 0; --i)
+    for (int64_t i = (int64_t)n-1; i > 0; --i)
         xb[i] = xb[i] + 2*x[i]*resb;
     xb[0] = xb[0] + 2*x[0]*resb;
 }
 
 // sum of component squares
-double sqnorm_nodiff(int n, const double *x) {
-    int i;
+double sqnorm_nodiff(size_t n, const double *x) {
+    size_t i;
     double res = x[0]*x[0];
     for (i = 1; i < n; ++i)
         res = res + x[i]*x[i];
@@ -370,18 +351,17 @@
    Plus diff mem management of: out:in y:in
 */
 // out = a - b
-void subtract_b(int d, const double *x, const double *y, double *yb, double *
+void subtract_b(size_t d, const double *x, const double *y, double *yb, double *
         out, double *outb) {
-    int id;
-    for (id = d-1; id > -1; --id) {
+    for (int64_t id = (int64_t)d-1; id > -1; --id) {
         yb[id] = yb[id] - outb[id];
         outb[id] = 0.0;
     }
 }
 
 // out = a - b
-void subtract_nodiff(int d, const double *x, const double *y, double *out) {
-    int id;
+void subtract_nodiff(size_t d, const double *x, const double *y, double *out) {
+    size_t id;
     for (id = 0; id < d; ++id)
         out[id] = x[id] - y[id];
 }
@@ -392,8 +372,7 @@
    with respect to varying inputs: *x
    Plus diff mem management of: x:in
 */
-void log_sum_exp_b(int n, const double *x, double *xb, double log_sum_expb) {
-    int i;
+void log_sum_exp_b(size_t n, const double *x, double *xb, double log_sum_expb) {
     double mx;
     double mxb;
     double tempb;
@@ -401,11 +380,11 @@
     mx = arr_max_nodiff(n, x);
     double semx = 0.0;
     double semxb = 0.0;
-    for (i = 0; i < n; ++i)
+    for (int64_t i = 0; i < n; ++i)
         semx = semx + exp(x[i] - mx);
     semxb = log_sum_expb/semx;
     mxb = log_sum_expb;
-    for (i = n-1; i > -1; --i) {
+    for (int64_t i = (int64_t)n-1; i > -1; --i) {
         tempb = exp(x[i]-mx)*semxb;
         xb[i] = xb[i] + tempb;
         mxb = mxb - tempb;
@@ -413,8 +392,8 @@
     arr_max_b(n, x, xb, mxb);
 }
 
-double log_sum_exp_nodiff(int n, const double *x) {
-    int i;
+double log_sum_exp_nodiff(size_t n, const double *x) {
+    size_t i;
     double mx;
     mx = arr_max_nodiff(n, x);
     double semx = 0.0;
@@ -424,7 +403,7 @@
 }
 
 double log_gamma_distrib_nodiff(double a, double p) {
-    int j;
+    int64_t j; /* signed, as in log_gamma_distrib: an unsigned j wraps in the (1 - j) term of log Gamma_p */
     /* TFIX */
     double out = 0.25*p*(p-1)*log(PI);
     double arg1;
@@ -446,12 +425,12 @@
  ========================================================================
                                 MAIN LOGIC
  ======================================================================== */
-void log_wishart_prior_b(int p, int k, Wishart wishart, const double *sum_qs,
+void log_wishart_prior_b(size_t p, size_t k, Wishart wishart, const double *sum_qs,
         double *sum_qsb, const double *Qdiags, double *Qdiagsb, const double *
         icf, double *icfb, double log_wishart_priorb) {
-    int ik;
-    int n = p + wishart.m + 1;
-    int icf_sz = p*(p+1)/2;
+    int64_t ik;
+    size_t n = p + wishart.m + 1;
+    size_t icf_sz = p*(p+1)/2;
     double C;
     float arg1;
     double result1;
@@ -461,7 +440,7 @@
     for (ik = 0; ik < k; ++ik) {
         double frobenius;
         double result1;
-        int arg1;
+        size_t arg1;
         double result2;
     }
     outb = log_wishart_priorb;
@@ -471,12 +450,12 @@
         sum_qsb[ik] = 0.0;
     for (ik = 0; ik < k * icf_sz; ik++) /* TFIX */
         icfb[ik] = 0.0;
-    for (ik = k-1; ik > -1; --ik) {
+    for (ik = (int64_t)k-1; ik > -1; --ik) {
         double frobenius;
         double frobeniusb;
         double result1;
         double result1b;
-        int arg1;
+        size_t arg1;
         double result2;
         double result2b;
         frobeniusb = wishart.gamma*wishart.gamma*0.5*outb;
@@ -493,11 +472,11 @@
 /* ========================================================================
                                 MAIN LOGIC
  ======================================================================== */
-double log_wishart_prior_nodiff(int p, int k, Wishart wishart, const double *
+double log_wishart_prior_nodiff(size_t p, size_t k, Wishart wishart, const double *
         sum_qs, const double *Qdiags, const double *icf) {
-    int ik;
-    int n = p + wishart.m + 1;
-    int icf_sz = p*(p+1)/2;
+    size_t ik;
+    size_t n = p + wishart.m + 1;
+    size_t icf_sz = p*(p+1)/2;
     double C;
     float arg1;
     double result1;
@@ -508,7 +487,7 @@
     for (ik = 0; ik < k; ++ik) {
         double frobenius;
         double result1;
-        int arg1;
+        size_t arg1;
         double result2;
         result1 = sqnorm_nodiff(p, &(Qdiags[ik*p]));
         arg1 = icf_sz - p;
@@ -526,17 +505,17 @@
    with respect to varying inputs: *icf
    Plus diff mem management of: Qdiags:in sum_qs:in icf:in
 */
-void preprocess_qs_b(int d, int k, const double *icf, double *icfb, double *
+void preprocess_qs_b(size_t d, size_t k, const double *icf, double *icfb, double *
         sum_qs, double *sum_qsb, double *Qdiags, double *Qdiagsb) {
-    int ik, id;
-    int icf_sz = d*(d+1)/2;
+    int64_t ik, id;
+    size_t icf_sz = d*(d+1)/2;
     for (ik = 0; ik < k; ++ik)
         for (id = 0; id < d; ++id) {
             double q = icf[ik*icf_sz + id];
             pushReal8(q);
         }
-    for (ik = k-1; ik > -1; --ik) {
-        for (id = d-1; id > -1; --id) {
+    for (ik = (int64_t)k-1; ik > -1; --ik) {
+        for (id = (int64_t)d-1; id > -1; --id) {
             double q;
             double qb = 0.0;
             popReal8(&q);
@@ -549,13 +528,12 @@
     }
 }
 
-void preprocess_qs_nodiff(int d, int k, const double *icf, double *sum_qs,
+void preprocess_qs_nodiff(size_t d, size_t k, const double *icf, double *sum_qs,
         double *Qdiags) {
-    int ik, id;
-    int icf_sz = d*(d+1)/2;
-    for (ik = 0; ik < k; ++ik) {
+    size_t icf_sz = d*(d+1)/2;
+    for (size_t ik = 0; ik < k; ++ik) {
         sum_qs[ik] = 0.;
-        for (id = 0; id < d; ++id) {
+        for (size_t id = 0; id < d; ++id) {
             double q = icf[ik*icf_sz + id];
             sum_qs[ik] = sum_qs[ik] + q;
             Qdiags[ik*d + id] = exp(q);
@@ -569,41 +547,41 @@
    with respect to varying inputs: *out *Qdiag *x *ltri
    Plus diff mem management of: out:in Qdiag:in x:in ltri:in
 */
-void Qtimesx_b(int d, const double *Qdiag, double *Qdiagb, const double *ltri,
+void Qtimesx_b(size_t d, const double *Qdiag, double *Qdiagb, const double *ltri,
         double *ltrib, const double *x, double *xb, double *out, double *outb)
 {
     // strictly lower triangular part
-    int i, j;
+    int64_t i, j;
     int adFrom;
-    int Lparamsidx = 0;
+    size_t Lparamsidx = 0;
     for (i = 0; i < d; ++i) {
         adFrom = i + 1;
         for (j = adFrom; j < d; ++j)
             Lparamsidx++;
         pushInteger4(adFrom);
     }
-    for (i = d-1; i > -1; --i) {
+    for (i = (int64_t)d-1; i > -1; --i) {
         popInteger4(&adFrom);
-        for (j = d-1; j > adFrom-1; --j) {
+        for (j = (int64_t)d-1; j > adFrom-1; --j) {
             --Lparamsidx;
             ltrib[Lparamsidx] = ltrib[Lparamsidx] + x[i]*outb[j];
             xb[i] = xb[i] + ltri[Lparamsidx]*outb[j];
         }
     }
-    for (i = d-1; i > -1; --i) {
+    for (i = (int64_t)d-1; i > -1; --i) {
         Qdiagb[i] = Qdiagb[i] + x[i]*outb[i];
         xb[i] = xb[i] + Qdiag[i]*outb[i];
         outb[i] = 0.0;
     }
 }
 
-void Qtimesx_nodiff(int d, const double *Qdiag, const double *ltri, const
+void Qtimesx_nodiff(size_t d, const double *Qdiag, const double *ltri, const
         double *x, double *out) {
     // strictly lower triangular part
-    int i, j;
+    size_t i, j;
     for (i = 0; i < d; ++i)
         out[i] = Qdiag[i]*x[i];
-    int Lparamsidx = 0;
+    size_t Lparamsidx = 0;
     for (i = 0; i < d; ++i)
         for (j = i+1; j < d; ++j) {
             out[j] = out[j] + ltri[Lparamsidx]*x[i];
@@ -619,19 +597,19 @@
                 *alphas:out
    Plus diff mem management of: err:in means:in icf:in alphas:in
 */
-void gmm_objective_b(int d, int k, int n, const double *alphas, double *
+void gmm_objective_b(size_t d, size_t k, size_t n, const double *alphas, double *
         alphasb, const double *means, double *meansb, const double *icf,
         double *icfb, const double *x, Wishart wishart, double *err, double *
         errb) {
-    int ix, ik;
+    int64_t ix, ik;
     /* TFIX */
-    const double CONSTANT = -n*d*0.5*log(2*PI);
-    int icf_sz = d*(d+1)/2;
+    const double CONSTANT = -(double)n*d*0.5*log(2*PI);
+    size_t icf_sz = d*(d+1)/2;
     double *Qdiags;
     double *Qdiagsb;
     double result1;
     double result1b;
-    int ii1;
+    size_t ii1;
     Qdiagsb = (double *)malloc(d*k*sizeof(double));
     for (ii1 = 0; ii1 < d*k; ++ii1)
         Qdiagsb[ii1] = 0.0;
@@ -687,10 +665,10 @@
     log_sum_exp_b(k, alphas, alphasb, lse_alphasb);
     for (ii1 = 0; ii1 < d * k; ii1++) /* TFIX */
         meansb[ii1] = 0.0;
-    for (ix = n-1; ix > -1; --ix) {
+    for (ix = (int64_t)n-1; ix > -1; --ix) {
         result1b = slseb;
         log_sum_exp_b(k, &(main_term[0]), &(main_termb[0]), result1b);
-        for (ik = k-1; ik > -1; --ik) {
+        for (ik = (int64_t)k-1; ik > -1; --ik) {
             popReal8(&(main_term[ik]));
             alphasb[ik] = alphasb[ik] + main_termb[ik];
             sum_qsb[ik] = sum_qsb[ik] + main_termb[ik];
@@ -733,32 +711,32 @@
 
     // out = a - b
 template<typename T1, typename T2, typename T3>
-void subtract(int d,
+void subtract(size_t d,
     const T1* const x,
     const T2* const y,
     T3* out)
 {
-    for (int id = 0; id < d; id++)
+    for (size_t id = 0; id < d; id++)
     {
         out[id] = x[id] - y[id];
     }
 }
 
 template<typename T>
-T sqnorm(int n, const T* const x)
+T sqnorm(size_t n, const T* const x)
 {
     T res = x[0] * x[0];
-    for (int i = 1; i < n; i++)
+    for (size_t i = 1; i < n; i++)
         res = res + x[i] * x[i];
     return res;
 }
 
 // This throws error on n<1
 template<typename T>
-T arr_max(int n, const T* const x)
+T arr_max(size_t n, const T* const x)
 {
     T m = x[0];
-    for (int i = 1; i < n; i++)
+    for (size_t i = 1; i < n; i++)
     {
         if (m < x[i])
             m = x[i];
@@ -767,12 +745,12 @@
 }
 
     template<typename T>
-void gmm_objective(int d, int k, int n, const T* const alphas, const T* const means,
+void gmm_objective(size_t d, size_t k, size_t n, const T* const alphas, const T* const means,
     const T* const icf, const double* const x, Wishart wishart, T* err);
 
 // split of the outer loop over points
 template<typename T>
-void gmm_objective_split_inner(int d, int k,
+void gmm_objective_split_inner(size_t d, size_t k,
     const T* const alphas,
     const T* const means,
     const T* const icf,
@@ -781,7 +759,7 @@
     T* err);
 // other terms which are outside the loop
 template<typename T>
-void gmm_objective_split_other(int d, int k, int n,
+void gmm_objective_split_other(size_t d, size_t k, size_t n,
     const T* const alphas,
     const T* const means,
     const T* const icf,
@@ -789,7 +767,7 @@
     T* err);
 
 template<typename T>
-T logsumexp(int n, const T* const x);
+T logsumexp(size_t n, const T* const x);
 
 // p: dim
 // k: number of components
@@ -798,20 +776,20 @@
 // Qdiags: d*k
 // icf: (p*(p+1)/2)*k inverse covariance factors
 template<typename T>
-T log_wishart_prior(int p, int k,
+T log_wishart_prior(size_t p, size_t k,
     Wishart wishart,
     const T* const sum_qs,
     const T* const Qdiags,
     const T* const icf);
 
 template<typename T>
-void preprocess_qs(int d, int k,
+void preprocess_qs(size_t d, size_t k,
     const T* const icf,
     T* sum_qs,
     T* Qdiags);
 
 template<typename T>
-void Qtimesx(int d,
+void Qtimesx(size_t d,
     const T* const Qdiag,
     const T* const ltri, // strictly lower triangular part
     const T* const x,
@@ -822,11 +800,11 @@
 ////////////////////////////////////////////////////////////
 
 template<typename T>
-T logsumexp(int n, const T* const x)
+T logsumexp(size_t n, const T* const x)
 {
     T mx = arr_max(n, x);
     T semx = 0.;
-    for (int i = 0; i < n; i++)
+    for (size_t i = 0; i < n; i++)
     {
         semx = semx + exp(x[i] - mx);
     }
@@ -834,19 +812,19 @@
 }
 
 template<typename T>
-T log_wishart_prior(int p, int k,
+T log_wishart_prior(size_t p, size_t k,
     Wishart wishart,
     const T* const sum_qs,
     const T* const Qdiags,
     const T* const icf)
 {
-    int n = p + wishart.m + 1;
-    int icf_sz = p * (p + 1) / 2;
+    size_t n = p + wishart.m + 1;
+    size_t icf_sz = p * (p + 1) / 2;
 
     double C = n * p * (log(wishart.gamma) - 0.5 * log(2)) - log_gamma_distrib(0.5 * n, p);
 
     T out = 0;
-    for (int ik = 0; ik < k; ik++)
+    for (size_t ik = 0; ik < k; ik++)
     {
         T frobenius = sqnorm(p, &Qdiags[ik * p]) + sqnorm(icf_sz - p, &icf[ik * icf_sz + p]);
         out = out + 0.5 * wishart.gamma * wishart.gamma * (frobenius)
@@ -857,16 +835,16 @@
 }
 
 template<typename T>
-void preprocess_qs(int d, int k,
+void preprocess_qs(size_t d, size_t k,
     const T* const icf,
     T* sum_qs,
     T* Qdiags)
 {
-    int icf_sz = d * (d + 1) / 2;
-    for (int ik = 0; ik < k; ik++)
+    size_t icf_sz = d * (d + 1) / 2;
+    for (size_t ik = 0; ik < k; ik++)
     {
         sum_qs[ik] = 0.;
-        for (int id = 0; id < d; id++)
+        for (size_t id = 0; id < d; id++)
         {
             T q = icf[ik * icf_sz + id];
             sum_qs[ik] = sum_qs[ik] + q;
@@ -876,19 +854,19 @@
 }
 
 template<typename T>
-void Qtimesx(int d,
+void Qtimesx(size_t d,
     const T* const Qdiag,
     const T* const ltri, // strictly lower triangular part
     const T* const x,
     T* out)
 {
-    for (int id = 0; id < d; id++)
+    for (size_t id = 0; id < d; id++)
         out[id] = Qdiag[id] * x[id];
 
-    int Lparamsidx = 0;
-    for (int i = 0; i < d; i++)
+    size_t Lparamsidx = 0;
+    for (size_t i = 0; i < d; i++)
     {
-        for (int j = i + 1; j < d; j++)
+        for (size_t j = i + 1; j < d; j++)
         {
             out[j] = out[j] + ltri[Lparamsidx] * x[i];
             Lparamsidx++;
@@ -897,7 +875,7 @@
 }
 
 template<typename T>
-void gmm_objective(int d, int k, int n,
+void gmm_objective(size_t d, size_t k, size_t n,
     const T* const alphas,
     const T* const means,
     const T* const icf,
@@ -905,8 +883,8 @@
     Wishart wishart,
     T* err)
 {
-    const double CONSTANT = -n * d * 0.5 * log(2 * PI);
-    int icf_sz = d * (d + 1) / 2;
+    const double CONSTANT = -(double)n * d * 0.5 * log(2 * PI);
+    size_t icf_sz = d * (d + 1) / 2;
 
     vector<T> Qdiags(d * k);
     vector<T> sum_qs(k);
@@ -917,9 +895,9 @@
     preprocess_qs(d, k, icf, &sum_qs[0], &Qdiags[0]);
 
     T slse = 0.;
-    for (int ix = 0; ix < n; ix++)
+    for (size_t ix = 0; ix < n; ix++)
     {
-        for (int ik = 0; ik < k; ik++)
+        for (size_t ik = 0; ik < k; ik++)
         {
             subtract(d, &x[ix * d], &means[ik * d], &xcentered[0]);
             Qtimesx(d, &Qdiags[ik * d], &icf[ik * icf_sz + d], &xcentered[0], &Qxcentered[0]);
@@ -937,7 +915,7 @@
 }
 
 template<typename T>
-void gmm_objective_split_inner(int d, int k,
+void gmm_objective_split_inner(size_t d, size_t k,
     const T* const alphas,
     const T* const means,
     const T* const icf,
@@ -945,39 +923,39 @@
     Wishart wishart,
     T* err)
 {
-    int icf_sz = d * (d + 1) / 2;
+    size_t icf_sz = d * (d + 1) / 2;
 
     T* Ldiag = new T[d];
     T* xcentered = new T[d];
     T* mahal = new T[d];
     T* lse = new T[k];
 
-    for (int ik = 0; ik < k; ik++)
+    for (size_t ik = 0; ik < k; ik++)
     {
-        int icf_off = ik * icf_sz;
+        size_t icf_off = ik * icf_sz;
         T sumlog_Ldiag(0.);
-        for (int id = 0; id < d; id++)
+        for (size_t id = 0; id < d; id++)
         {
             sumlog_Ldiag = sumlog_Ldiag + icf[icf_off + id];
             Ldiag[id] = exp(icf[icf_off + id]);
         }
 
-        for (int id = 0; id < d; id++)
+        for (size_t id = 0; id < d; id++)
         {
             xcentered[id] = x[id] - means[ik * d + id];
             mahal[id] = Ldiag[id] * xcentered[id];
         }
-        int Lparamsidx = d;
-        for (int i = 0; i < d; i++)
+        size_t Lparamsidx = d;
+        for (size_t i = 0; i < d; i++)
         {
-            for (int j = i + 1; j < d; j++)
+            for (size_t j = i + 1; j < d; j++)
             {
                 mahal[j] = mahal[j] + icf[icf_off + Lparamsidx] * xcentered[i];
                 Lparamsidx++;
             }
         }
         T sqsum_mahal(0.);
-        for (int id = 0; id < d; id++)
+        for (size_t id = 0; id < d; id++)
         {
             sqsum_mahal = sqsum_mahal + mahal[id] * mahal[id];
         }
@@ -994,14 +972,14 @@
 }
 
 template<typename T>
-void gmm_objective_split_other(int d, int k, int n,
+void gmm_objective_split_other(size_t d, size_t k, size_t n,
     const T* const alphas,
     const T* const means,
     const T* const icf,
     Wishart wishart,
     T* err)
 {
-    const double CONSTANT = -n * d * 0.5 * log(2 * PI);
+    const double CONSTANT = -(double)n * d * 0.5 * log(2 * PI);
 
     T lse_alphas = logsumexp(k, alphas);
 
@@ -1015,14 +993,14 @@
 
 };
 
-void adept_dgmm_objective(int d, int k, int n, const double *alphas, double *
+void adept_dgmm_objective(size_t d, size_t k, size_t n, const double *alphas, double *
         alphasb, const double *means, double *meansb, const double *icf,
         double *icfb, const double *x, Wishart wishart, double *err, double *
         errb) {
 
-  int icf_sz = d*(d + 1) / 2;
-  int Jrows = 1;
-  int Jcols = (k*(d + 1)*(d + 2)) / 2;
+  size_t icf_sz = d*(d + 1) / 2;
+  size_t Jrows = 1;
+  size_t Jcols = (k*(d + 1)*(d + 2)) / 2;
 
   adept::Stack stack;
   adouble *aalphas = new adouble[k];
@@ -1050,3 +1028,5 @@
   delete[] ameans;
   delete[] aicf;
 }
+
+#include "gmm_mayalias.h"
diff --git a/enzyme/benchmarks/ReverseMode/gmm/gmm.h b/enzyme/benchmarks/ReverseMode/gmm/gmm.h
index eb189af..5dc5bc0 100644
--- a/enzyme/benchmarks/ReverseMode/gmm/gmm.h
+++ b/enzyme/benchmarks/ReverseMode/gmm/gmm.h
@@ -28,9 +28,9 @@
 // wishart: wishart distribution parameters
 // err: 1 output
 void gmm_objective(
-    int d,
-    int k,
-    int n,
+    size_t d,
+    size_t k,
+    size_t n,
     double const* alphas,
     double const* means,
     double const* icf,
diff --git a/enzyme/benchmarks/ReverseMode/gmm/gmm_mayalias.h b/enzyme/benchmarks/ReverseMode/gmm/gmm_mayalias.h
new file mode 100644
index 0000000..4bcba4f
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/gmm/gmm_mayalias.h
@@ -0,0 +1,82 @@
+// "May alias" variants of the GMM objective and its Enzyme gradient: the
+// pointer parameters are not __restrict-qualified, unlike
+// gmm_objective_restrict / dgmm_objective_restrict in gmm.cpp.
+//
+// Not self-contained: this file is #included at the end of gmm.cpp and uses
+// Wishart, PI, preprocess_qs, subtract, Qtimesx, sqnorm, log_sum_exp,
+// log_wishart_prior, enzyme_const/enzyme_dup/enzyme_dupnoneed and
+// __enzyme_autodiff, all of which must be declared before the #include.
+// NOTE(review): these definitions come after gmm.cpp's extern "C" block, so
+// they only get C linkage if an earlier declaration (e.g. in gmm.h) has it.
+#pragma once
+
+// Computes the GMM objective into *err.
+//   d: point dimension, k: number of mixture components, n: number of points
+//   alphas[k], means[k*d], x[n*d], icf[k*d*(d+1)/2] (per component: d
+//   log-diagonal entries, then the strictly lower-triangular entries)
+// Scratch buffers are malloc'd and freed on every call.
+void gmm_objective(size_t d, size_t k, size_t n, double const *alphas,
+                   double const *means, double const *icf, double const *x,
+                   Wishart wishart, double *err) {
+  size_t ix, ik;
+  const double CONSTANT = -(double)n * d * 0.5 * log(2 * PI);
+  size_t icf_sz = d * (d + 1) / 2;
+
+  double *Qdiags = (double *)malloc(d * k * sizeof(double));
+  double *sum_qs = (double *)malloc(k * sizeof(double));
+  double *xcentered = (double *)malloc(d * sizeof(double));
+  double *Qxcentered = (double *)malloc(d * sizeof(double));
+  double *main_term = (double *)malloc(k * sizeof(double));
+
+  preprocess_qs(d, k, icf, &sum_qs[0], &Qdiags[0]);
+
+  double slse = 0.;
+  for (ix = 0; ix < n; ix++) {
+    for (ik = 0; ik < k; ik++) {
+      subtract(d, &x[ix * d], &means[ik * d], &xcentered[0]);
+      Qtimesx(d, &Qdiags[ik * d], &icf[ik * icf_sz + d], &xcentered[0],
+              &Qxcentered[0]);
+      // two caches for qxcentered at idx 0 and at arbitrary index
+      main_term[ik] = alphas[ik] + sum_qs[ik] - 0.5 * sqnorm(d, &Qxcentered[0]);
+    }
+
+    // storing cmp for max of main_term
+    // 2 x (0 and arbitrary) storing sub to exp
+    // storing sum for use in log
+    slse = slse + log_sum_exp(k, &main_term[0]);
+  }
+
+  // storing cmp of alphas
+  double lse_alphas = log_sum_exp(k, alphas);
+
+  *err = CONSTANT + slse - n * lse_alphas +
+         log_wishart_prior(d, k, wishart, &sum_qs[0], &Qdiags[0], icf);
+
+  free(Qdiags);
+  free(sum_qs);
+  free(xcentered);
+  free(Qxcentered);
+  free(main_term);
+}
+
+// Reverse-mode gradient of gmm_objective generated by Enzyme. d, k, n, x and
+// wishart are inactive (enzyme_const); derivatives w.r.t. alphas, means and
+// icf are accumulated into alphasb, meansb and icfb, with *errb as the seed.
+// err is passed as enzyme_dupnoneed: only its shadow errb is required.
+// *      tapenade -b -o gmm_tapenade -head "gmm_objective(err)/(alphas means icf)" gmm.c
+void dgmm_objective(size_t d, size_t k, size_t n, const double *alphas, double *
+        alphasb, const double *means, double *meansb, const double *icf,
+        double *icfb, const double *x, Wishart wishart, double *err, double *
+        errb) {
+    __enzyme_autodiff(
+            gmm_objective,
+            enzyme_const, d,
+            enzyme_const, k,
+            enzyme_const, n,
+            enzyme_dup, alphas, alphasb,
+            enzyme_dup, means, meansb,
+            enzyme_dup, icf, icfb,
+            enzyme_const, x,
+            enzyme_const, wishart,
+            enzyme_dupnoneed, err, errb);
+}
diff --git a/enzyme/benchmarks/ReverseMode/gmm/src/lib.rs b/enzyme/benchmarks/ReverseMode/gmm/src/lib.rs
new file mode 100644
index 0000000..4f9fc53
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/gmm/src/lib.rs
@@ -0,0 +1,19 @@
+//! Rust implementations of the GMM benchmark objective and its gradient.
+//!
+//! `safe` works on slices; `r#unsafe` exposes a raw-pointer `dgmm_objective`.
+#![feature(autodiff)]
+pub mod safe;
+pub mod r#unsafe;
+
+/// Wishart prior parameters used by the GMM objective.
+///
+/// `#[repr(C)]` because C callers hand it over as `*const Wishart` (see
+/// `safe::rust_dgmm_objective`), so the layout must match the C++ `Wishart`.
+/// NOTE(review): gmm.cpp now documents `m` as `size_t`; confirm the C++
+/// definition still uses a 32-bit `int` before relying on `i32` here.
+#[derive(Clone, Copy, Debug)]
+#[repr(C)]
+pub struct Wishart {
+    pub gamma: f64,
+    pub m: i32,
+}
diff --git a/enzyme/benchmarks/ReverseMode/gmm/src/main.rs b/enzyme/benchmarks/ReverseMode/gmm/src/main.rs
new file mode 100644
index 0000000..e7ebf74
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/gmm/src/main.rs
@@ -0,0 +1,45 @@
+#![feature(autodiff)]
+use gmmrs::{Wishart, r#unsafe::dgmm_objective};
+
+/// Runs the raw-pointer `r#unsafe::dgmm_objective` once on a tiny problem
+/// (d = 2 dimensions, k = 2 components, n = 2 points) and prints the gradient.
+///
+/// Buffer lengths must match what the objective indexes:
+/// alphas: k, means: k*d, icf: k*d*(d+1)/2, x: n*d.
+fn main() {
+    let d = 2;
+    let k = 2;
+    let n = 2;
+    let alphas = vec![0.5, 0.5];
+    let means = vec![0., 0., 1., 1.];
+    // Per component: d log-diagonal entries followed by the d*(d-1)/2
+    // strictly lower-triangular entries, i.e. 3 values each for d = 2.
+    let icf = vec![1., 0., 1., 1., 0., 1.];
+    let x = vec![0., 0., 1., 1.];
+    let wishart = Wishart { gamma: 1., m: 1 };
+    let mut err: f64 = 0.;
+    let mut d_alphas = vec![0.; alphas.len()];
+    let mut d_means = vec![0.; means.len()];
+    let mut d_icf = vec![0.; icf.len()];
+    let mut d_err: f64 = 1.; // adjoint seed for err; 0 would zero every gradient
+    // SAFETY: every pointer comes from a live local whose length matches the
+    // sizes documented above, and each shadow buffer is distinct from its primal.
+    unsafe {
+        dgmm_objective(
+            d,
+            k,
+            n,
+            alphas.as_ptr(),
+            d_alphas.as_mut_ptr(),
+            means.as_ptr(),
+            d_means.as_mut_ptr(),
+            icf.as_ptr(),
+            d_icf.as_mut_ptr(),
+            x.as_ptr(),
+            &wishart as *const Wishart,
+            &mut err as *mut f64,
+            &mut d_err as *mut f64,
+        );
+    }
+    println!("d_alphas = {d_alphas:?}\nd_means = {d_means:?}\nd_icf = {d_icf:?}");
+}
diff --git a/enzyme/benchmarks/ReverseMode/gmm/src/safe.rs b/enzyme/benchmarks/ReverseMode/gmm/src/safe.rs
new file mode 100644
index 0000000..9356b11
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/gmm/src/safe.rs
@@ -0,0 +1,358 @@
+//! Safe, slice-based Rust port of the ADBench GMM objective.
+//!
+//! `gmm_objective` is differentiated in reverse mode by `#[autodiff]`, which
+//! generates `dgmm_objective`; both are exposed to C through the
+//! `rust_gmm_objective` / `rust_dgmm_objective` wrappers below.
+use crate::Wishart;
+use std::autodiff::autodiff;
+use std::f64::consts::PI;
+
+#[cfg(feature = "libm")]
+use libm::lgamma;
+
+// Without the `libm` feature, `lgamma` is taken from the platform C math library.
+#[cfg(not(feature = "libm"))]
+mod cmath {
+    extern "C" {
+        pub fn lgamma(x: f64) -> f64;
+    }
+}
+#[cfg(not(feature = "libm"))]
+#[inline]
+fn lgamma(x: f64) -> f64 {
+    // SAFETY: `lgamma` takes and returns a plain f64; no pointers are involved.
+    unsafe { cmath::lgamma(x) }
+}
+
+/// C entry point for the reverse-mode gradient of [`gmm_objective`].
+///
+/// Gradients are accumulated into `dalphas`, `dmeans` and `dicf`, scaled by the
+/// seed in `*derr`; the values of `*err` and `*derr` are written back afterwards.
+///
+/// The caller must pass `d, k, n >= 0`, valid `wishart`, `err` and `derr`, and
+/// non-aliasing buffers of at least `k` (`alphas`, `dalphas`), `k * d` (`means`,
+/// `dmeans`), `k * d * (d + 1) / 2` (`icf`, `dicf`) and `n * d` (`x`) f64s.
+#[no_mangle]
+pub extern "C" fn rust_dgmm_objective(
+    d: i32,
+    k: i32,
+    n: i32,
+    alphas: *const f64,
+    dalphas: *mut f64,
+    means: *const f64,
+    dmeans: *mut f64,
+    icf: *const f64,
+    dicf: *mut f64,
+    x: *const f64,
+    wishart: *const Wishart,
+    err: *mut f64,
+    derr: *mut f64,
+) {
+    let (d, k, n) = (d as usize, k as usize, n as usize);
+    let icf_len = k * d * (d + 1) / 2;
+    // SAFETY: lengths and validity follow the caller contract documented above.
+    let (alphas, means, icf, x, wishart) = unsafe {
+        (
+            std::slice::from_raw_parts(alphas, k),
+            std::slice::from_raw_parts(means, k * d),
+            std::slice::from_raw_parts(icf, icf_len),
+            std::slice::from_raw_parts(x, n * d),
+            *wishart,
+        )
+    };
+    // SAFETY: as above; the shadows must not alias the primal buffers.
+    let (d_alphas, d_means, d_icf) = unsafe {
+        (
+            std::slice::from_raw_parts_mut(dalphas, k),
+            std::slice::from_raw_parts_mut(dmeans, k * d),
+            std::slice::from_raw_parts_mut(dicf, icf_len),
+        )
+    };
+    // SAFETY: `err` and `derr` point to valid f64s per the caller contract.
+    let (mut my_err, mut my_derr) = unsafe { (*err, *derr) };
+
+    // Scratch space for the primal pass and its shadow (gradient) twin.
+    let (mut qdiags, mut sum_qs, mut xcentered, mut qxcentered, mut main_term) =
+        get_workspace(d, k);
+    let (mut bqdiags, mut bsum_qs, mut bxcentered, mut bqxcentered, mut bmain_term) =
+        get_workspace(d, k);
+
+    unsafe {
+        dgmm_objective(
+            d,
+            k,
+            n,
+            alphas,
+            d_alphas,
+            means,
+            d_means,
+            icf,
+            d_icf,
+            x,
+            wishart.gamma,
+            wishart.m,
+            &mut my_err,
+            &mut my_derr,
+            &mut qdiags,
+            &mut bqdiags,
+            &mut sum_qs,
+            &mut bsum_qs,
+            &mut xcentered,
+            &mut bxcentered,
+            &mut qxcentered,
+            &mut bqxcentered,
+            &mut main_term,
+            &mut bmain_term,
+        )
+    };
+
+    // SAFETY: same pointers as read above.
+    unsafe {
+        *err = my_err;
+        *derr = my_derr;
+    }
+}
+
+/// C entry point for the primal [`gmm_objective`]; writes the objective to `*err`.
+///
+/// Buffer requirements are those of [`rust_dgmm_objective`] without the shadows.
+#[no_mangle]
+pub extern "C" fn rust_gmm_objective(
+    d: i32,
+    k: i32,
+    n: i32,
+    alphas: *const f64,
+    means: *const f64,
+    icf: *const f64,
+    x: *const f64,
+    wishart: *const Wishart,
+    err: *mut f64,
+) {
+    let (d, k, n) = (d as usize, k as usize, n as usize);
+    // SAFETY: lengths and validity follow the documented caller contract.
+    let (alphas, means, icf, x, wishart) = unsafe {
+        (
+            std::slice::from_raw_parts(alphas, k),
+            std::slice::from_raw_parts(means, k * d),
+            std::slice::from_raw_parts(icf, k * d * (d + 1) / 2),
+            std::slice::from_raw_parts(x, n * d),
+            *wishart,
+        )
+    };
+    // SAFETY: `err` points to a valid f64.
+    let mut my_err = unsafe { *err };
+    let (mut qdiags, mut sum_qs, mut xcentered, mut qxcentered, mut main_term) =
+        get_workspace(d, k);
+    gmm_objective(
+        d,
+        k,
+        n,
+        alphas,
+        means,
+        icf,
+        x,
+        wishart.gamma,
+        wishart.m,
+        &mut my_err,
+        &mut qdiags,
+        &mut sum_qs,
+        &mut xcentered,
+        &mut qxcentered,
+        &mut main_term,
+    );
+    // SAFETY: as above.
+    unsafe { *err = my_err };
+}
+
+/// Allocates zeroed scratch buffers for [`gmm_objective`], in the order
+/// `(qdiags: d * k, sum_qs: k, xcentered: d, qxcentered: d, main_term: k)`.
+fn get_workspace(d: usize, k: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
+    (vec![0.; d * k], vec![0.; k], vec![0.; d], vec![0.; d], vec![0.; k])
+}
+
+/// ADBench GMM objective; the result is written to `*err`.
+///
+/// * `alphas`: `k` mixture-weight parameters (normalized via log-sum-exp).
+/// * `means`: `k` means of dimension `d`.
+/// * `icf`: `k` blocks of `d * (d + 1) / 2` entries: the `d` log-diagonal entries
+///   of each inverse-covariance factor, then its strictly-lower-triangular part.
+/// * `x`: `n` data points of dimension `d`.
+/// * `gamma`, `m`: Wishart prior hyper-parameters.
+///
+/// The trailing slices are caller-owned scratch space (sized as by
+/// [`get_workspace`]) so that `#[autodiff]` can be handed explicit shadows.
+///
+/// # Panics
+/// If a scratch slice does not have the length produced by `get_workspace`.
+#[autodiff(
+    dgmm_objective,
+    Reverse,
+    Const,          // d
+    Const,          // k
+    Const,          // n
+    Duplicated,     // alphas
+    Duplicated,     // means
+    Duplicated,     // icf
+    Const,          // x
+    Const,          // gamma
+    Const,          // m
+    DuplicatedOnly, // err
+    Duplicated,     // qdiags
+    Duplicated,     // sum_qs
+    Duplicated,     // xcentered
+    Duplicated,     // qxcentered
+    Duplicated      // main_term
+)]
+pub fn gmm_objective(
+    d: usize,
+    k: usize,
+    n: usize,
+    alphas: &[f64],
+    means: &[f64],
+    icf: &[f64],
+    x: &[f64],
+    gamma: f64,
+    m: i32,
+    err: &mut f64,
+    qdiags: &mut [f64],
+    sum_qs: &mut [f64],
+    xcentered: &mut [f64],
+    qxcentered: &mut [f64],
+    main_term: &mut [f64],
+) {
+    let wishart = Wishart { gamma, m };
+    let constant = -(n as f64) * d as f64 * 0.5 * (2.0 * PI).ln();
+    let icf_sz = d * (d + 1) / 2;
+
+    // Pin the scratch sizes up front so the compiler can drop bounds checks.
+    assert_eq!(qdiags.len(), d * k);
+    assert_eq!(sum_qs.len(), k);
+    assert_eq!(xcentered.len(), d);
+    assert_eq!(qxcentered.len(), d);
+    assert_eq!(main_term.len(), k);
+
+    preprocess_qs(d, k, icf, sum_qs, qdiags);
+
+    let mut slse = 0.;
+    for ix in 0..n {
+        for ik in 0..k {
+            subtract(d, &x[ix * d..], &means[ik * d..], xcentered);
+            qtimesx(d, &qdiags[ik * d..], &icf[ik * icf_sz + d..], xcentered, qxcentered);
+            main_term[ik] = alphas[ik] + sum_qs[ik] - 0.5 * sqnorm(qxcentered);
+        }
+        slse += log_sum_exp(k, main_term);
+    }
+
+    let lse_alphas = log_sum_exp(k, alphas);
+
+    *err = constant + slse - n as f64 * lse_alphas
+        + log_wishart_prior(d, k, wishart, sum_qs, qdiags, icf);
+}
+
+/// Largest of the first `n` entries of `x` (`-inf` when `n == 0`).
+fn arr_max(n: usize, x: &[f64]) -> f64 {
+    let mut max = f64::NEG_INFINITY;
+    for &v in &x[..n] {
+        if max < v {
+            max = v;
+        }
+    }
+    max
+}
+
+/// For each component `ik`: `sum_qs[ik]` = sum of its `d` log-diagonal `icf`
+/// entries, and `qdiags[ik * d..][..d]` = their exponentials.
+fn preprocess_qs(d: usize, k: usize, icf: &[f64], sum_qs: &mut [f64], qdiags: &mut [f64]) {
+    let icf_sz = d * (d + 1) / 2;
+    for ik in 0..k {
+        sum_qs[ik] = 0.;
+        for id in 0..d {
+            let q = icf[ik * icf_sz + id];
+            sum_qs[ik] += q;
+            qdiags[ik * d + id] = q.exp();
+        }
+    }
+}
+
+/// `out[..d] = x[..d] - y[..d]`.
+fn subtract(d: usize, x: &[f64], y: &[f64], out: &mut [f64]) {
+    assert!(x.len() >= d);
+    assert!(y.len() >= d);
+    assert!(out.len() >= d);
+    for i in 0..d {
+        out[i] = x[i] - y[i];
+    }
+}
+
+/// `out[..d] = Q * x[..d]` for a lower-triangular `Q` with diagonal `q_diag[..d]`
+/// and strictly-lower part packed column by column in `ltri`.
+fn qtimesx(d: usize, q_diag: &[f64], ltri: &[f64], x: &[f64], out: &mut [f64]) {
+    assert!(out.len() >= d);
+    assert!(q_diag.len() >= d);
+    assert!(x.len() >= d);
+    assert!(ltri.len() >= d * d.saturating_sub(1) / 2);
+    for i in 0..d {
+        out[i] = q_diag[i] * x[i];
+    }
+
+    for i in 0..d {
+        // Column `i` of the packed strict lower triangle starts here.
+        let mut lparamsidx = i * (2 * d - i - 1) / 2;
+        for j in i + 1..d {
+            out[j] += ltri[lparamsidx] * x[i];
+            lparamsidx += 1;
+        }
+    }
+}
+
+/// Numerically stable `ln(sum(exp(x[..n])))`: shifts by the maximum first.
+fn log_sum_exp(n: usize, x: &[f64]) -> f64 {
+    // Restrict to the first `n` entries so the shift and the sum agree.
+    let x = &x[..n];
+    let mx = arr_max(n, x);
+    let semx: f64 = x.iter().map(|v| (v - mx).exp()).sum();
+    semx.ln() + mx
+}
+
+/// Log of the multivariate gamma function `Γ_p(a)`.
+fn log_gamma_distrib(a: f64, p: f64) -> f64 {
+    0.25 * p * (p - 1.) * PI.ln()
+        + (1..=p as usize)
+            .map(|j| lgamma(a + 0.5 * (1. - j as f64)))
+            .sum::<f64>()
+}
+
+/// Log Wishart prior over the `k` inverse-covariance factors of dimension `p`;
+/// `sum_qs` and `qdiags` are the outputs of [`preprocess_qs`].
+fn log_wishart_prior(
+    p: usize,
+    k: usize,
+    wishart: Wishart,
+    sum_qs: &[f64],
+    qdiags: &[f64],
+    icf: &[f64],
+) -> f64 {
+    // Computed in f64 so that a negative `m` cannot overflow `usize`.
+    let n = p as f64 + f64::from(wishart.m) + 1.0;
+    let icf_sz = p * (p + 1) / 2;
+
+    let c = n * p as f64 * (wishart.gamma.ln() - 0.5 * 2f64.ln())
+        - log_gamma_distrib(0.5 * n, p as f64);
+
+    let out = (0..k)
+        .map(|ik| {
+            // Squared Frobenius norm of factor `ik`: diagonal + strict lower part.
+            let frobenius = sqnorm(&qdiags[ik * p..][..p])
+                + sqnorm(&icf[ik * icf_sz + p..][..icf_sz - p]);
+            0.5 * wishart.gamma * wishart.gamma * frobenius
+                - f64::from(wishart.m) * sum_qs[ik]
+        })
+        .sum::<f64>();
+
+    out - k as f64 * c
+}
+
+/// Sum of squares of `x`.
+fn sqnorm(x: &[f64]) -> f64 {
+    x.iter().map(|v| v * v).sum()
+}
diff --git a/enzyme/benchmarks/ReverseMode/gmm/src/unsafe.rs b/enzyme/benchmarks/ReverseMode/gmm/src/unsafe.rs
new file mode 100644
index 0000000..aa91938
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/gmm/src/unsafe.rs
@@ -0,0 +1,266 @@
+//! Raw-pointer Rust port of the ADBench GMM objective; `#[autodiff]` generates
+//! the reverse-mode `dgmm_objective` from `gmm_objective`.
+use crate::Wishart;
+use std::autodiff::autodiff;
+use std::f64::consts::PI;
+
+#[cfg(feature = "libm")]
+use libm::lgamma;
+
+// Without the `libm` feature, `lgamma` is taken from the platform C math library.
+#[cfg(not(feature = "libm"))]
+mod cmath {
+    extern "C" {
+        pub fn lgamma(x: f64) -> f64;
+    }
+}
+#[cfg(not(feature = "libm"))]
+#[inline]
+fn lgamma(x: f64) -> f64 {
+    // SAFETY: `lgamma` takes and returns a plain f64; no pointers are involved.
+    unsafe { cmath::lgamma(x) }
+}
+
+/// C entry point for the reverse-mode gradient of [`gmm_objective`].
+///
+/// The pointers are forwarded unchecked: the caller must meet the `# Safety`
+/// contract of [`gmm_objective`], and additionally pass shadows `dalphas`,
+/// `dmeans`, `dicf` sized like their primals and a valid seed in `*derr`.
+#[no_mangle]
+pub extern "C" fn rust_unsafe_dgmm_objective(
+    d: i32,
+    k: i32,
+    n: i32,
+    alphas: *const f64,
+    dalphas: *mut f64,
+    means: *const f64,
+    dmeans: *mut f64,
+    icf: *const f64,
+    dicf: *mut f64,
+    x: *const f64,
+    wishart: *const Wishart,
+    err: *mut f64,
+    derr: *mut f64,
+) {
+    let (d, k, n) = (d as usize, k as usize, n as usize);
+    // SAFETY: validity of every pointer is the C caller's responsibility.
+    unsafe {
+        dgmm_objective(
+            d, k, n, alphas, dalphas, means, dmeans, icf, dicf, x, wishart, err, derr,
+        );
+    }
+}
+
+/// C entry point for the primal [`gmm_objective`]; pointers are forwarded unchecked.
+#[no_mangle]
+pub extern "C" fn rust_unsafe_gmm_objective(
+    d: i32,
+    k: i32,
+    n: i32,
+    alphas: *const f64,
+    means: *const f64,
+    icf: *const f64,
+    x: *const f64,
+    wishart: *const Wishart,
+    err: *mut f64,
+) {
+    let (d, k, n) = (d as usize, k as usize, n as usize);
+    // SAFETY: validity of every pointer is the C caller's responsibility.
+    unsafe { gmm_objective(d, k, n, alphas, means, icf, x, wishart, err) }
+}
+
+/// ADBench GMM objective over raw pointers; the result is written to `*err`.
+///
+/// `icf` holds, per component, `d` log-diagonal entries followed by the
+/// strictly-lower-triangular part of its inverse-covariance factor.
+///
+/// # Safety
+/// `alphas` must be valid for `k` reads, `means` for `k * d`, `icf` for
+/// `k * d * (d + 1) / 2`, `x` for `n * d`; `wishart` must be readable and
+/// `err` writable.
+#[autodiff(
+    dgmm_objective,
+    Reverse,
+    Const,         // d
+    Const,         // k
+    Const,         // n
+    Duplicated,    // alphas
+    Duplicated,    // means
+    Duplicated,    // icf
+    Const,         // x
+    Const,         // wishart
+    DuplicatedOnly // err
+)]
+pub unsafe fn gmm_objective(
+    d: usize,
+    k: usize,
+    n: usize,
+    alphas: *const f64,
+    means: *const f64,
+    icf: *const f64,
+    x: *const f64,
+    wishart: *const Wishart,
+    err: *mut f64,
+) {
+    let constant = -(n as f64) * d as f64 * 0.5 * (2.0 * PI).ln();
+    let icf_sz = d * (d + 1) / 2;
+    let mut qdiags = vec![0.; d * k];
+    let mut sum_qs = vec![0.; k];
+    let mut xcentered = vec![0.; d];
+    let mut qxcentered = vec![0.; d];
+    let mut main_term = vec![0.; k];
+
+    preprocess_qs(d, k, icf, sum_qs.as_mut_ptr(), qdiags.as_mut_ptr());
+
+    let mut slse = 0.;
+    for ix in 0..n {
+        for ik in 0..k {
+            subtract(d, x.add(ix * d), means.add(ik * d), xcentered.as_mut_ptr());
+            qtimesx(
+                d,
+                qdiags.as_ptr().add(ik * d),
+                icf.add(ik * icf_sz + d),
+                xcentered.as_ptr(),
+                qxcentered.as_mut_ptr(),
+            );
+            main_term[ik] = *alphas.add(ik) + sum_qs[ik] - 0.5 * sqnorm(d, qxcentered.as_ptr());
+        }
+
+        slse = slse + log_sum_exp(k, main_term.as_ptr());
+    }
+
+    let lse_alphas = log_sum_exp(k, alphas);
+
+    *err = constant + slse - n as f64 * lse_alphas
+        + log_wishart_prior(d, k, *wishart, sum_qs.as_ptr(), qdiags.as_ptr(), icf);
+}
+
+/// Largest of the `n` values at `x` (`-inf` when `n == 0`).
+///
+/// # Safety
+/// `x` must be valid for `n` reads.
+unsafe fn arr_max(n: usize, x: *const f64) -> f64 {
+    let mut max = f64::NEG_INFINITY;
+    for i in 0..n {
+        if max < *x.add(i) {
+            max = *x.add(i);
+        }
+    }
+    max
+}
+
+/// Per component `ik`: `sum_qs[ik]` = sum of its `d` log-diagonal `icf`
+/// entries, `qdiags[ik * d..][..d]` = their exponentials.
+///
+/// # Safety
+/// `icf` must be valid for `k * d * (d + 1) / 2` reads, `sum_qs` for `k`
+/// reads/writes and `qdiags` for `k * d` writes.
+unsafe fn preprocess_qs(d: usize, k: usize, icf: *const f64, sum_qs: *mut f64, qdiags: *mut f64) {
+    let icf_sz = d * (d + 1) / 2;
+    for ik in 0..k {
+        *sum_qs.add(ik) = 0.;
+        for id in 0..d {
+            let q = *icf.add(ik * icf_sz + id);
+            *sum_qs.add(ik) = *sum_qs.add(ik) + q;
+            *qdiags.add(ik * d + id) = q.exp();
+        }
+    }
+}
+
+/// `out[..d] = x[..d] - y[..d]`.
+///
+/// # Safety
+/// `x` and `y` must be valid for `d` reads and `out` for `d` writes.
+unsafe fn subtract(d: usize, x: *const f64, y: *const f64, out: *mut f64) {
+    for i in 0..d {
+        *out.add(i) = *x.add(i) - *y.add(i);
+    }
+}
+
+/// `out[..d] = Q * x[..d]` for a lower-triangular `Q` with diagonal `q_diag`
+/// and strictly-lower part packed column by column in `ltri`.
+///
+/// # Safety
+/// `q_diag`, `x` valid for `d` reads, `ltri` for `d * (d - 1) / 2` reads and
+/// `out` for `d` reads/writes; `out` must not alias `x`.
+unsafe fn qtimesx(d: usize, q_diag: *const f64, ltri: *const f64, x: *const f64, out: *mut f64) {
+    for i in 0..d {
+        *out.add(i) = *q_diag.add(i) * *x.add(i);
+    }
+
+    for i in 0..d {
+        // Column `i` of the packed strict lower triangle starts here.
+        let mut lparamsidx = i * (2 * d - i - 1) / 2;
+        for j in i + 1..d {
+            *out.add(j) = *out.add(j) + *ltri.add(lparamsidx) * *x.add(i);
+            lparamsidx += 1;
+        }
+    }
+}
+
+/// Numerically stable `ln(sum(exp(x[..n])))`: shifts by the maximum first.
+///
+/// # Safety
+/// `x` must be valid for `n` reads.
+unsafe fn log_sum_exp(n: usize, x: *const f64) -> f64 {
+    let mx = arr_max(n, x);
+    let mut semx: f64 = 0.0;
+
+    for i in 0..n {
+        semx = semx + (*x.add(i) - mx).exp();
+    }
+    semx.ln() + mx
+}
+
+/// Log of the multivariate gamma function `Γ_p(a)`.
+fn log_gamma_distrib(a: f64, p: f64) -> f64 {
+    0.25 * p * (p - 1.) * PI.ln()
+        + (1..=p as usize)
+            .map(|j| lgamma(a + 0.5 * (1. - j as f64)))
+            .sum::<f64>()
+}
+
+/// Log Wishart prior over the `k` inverse-covariance factors of dimension `p`.
+///
+/// # Safety
+/// `sum_qs` must be valid for `k` reads, `qdiags` for `k * p` and `icf` for
+/// `k * p * (p + 1) / 2`.
+unsafe fn log_wishart_prior(
+    p: usize,
+    k: usize,
+    wishart: Wishart,
+    sum_qs: *const f64,
+    qdiags: *const f64,
+    icf: *const f64,
+) -> f64 {
+    // Computed in f64 so that a negative `m` cannot overflow `usize`.
+    let n = p as f64 + f64::from(wishart.m) + 1.0;
+    let icf_sz = p * (p + 1) / 2;
+
+    let c = n * p as f64 * (wishart.gamma.ln() - 0.5 * 2f64.ln())
+        - log_gamma_distrib(0.5 * n, p as f64);
+
+    let mut out = 0.;
+
+    for ik in 0..k {
+        // Squared Frobenius norm of factor `ik`: diagonal + strict lower part.
+        let frobenius =
+            sqnorm(p, qdiags.add(ik * p)) + sqnorm(icf_sz - p, icf.add(ik * icf_sz + p));
+        out = out + 0.5 * wishart.gamma * wishart.gamma * (frobenius)
+            - f64::from(wishart.m) * *sum_qs.add(ik);
+    }
+
+    out - k as f64 * c
+}
+
+/// Sum of squares of the `n` values at `x`.
+///
+/// # Safety
+/// `x` must be valid for `n` reads.
+unsafe fn sqnorm(n: usize, x: *const f64) -> f64 {
+    let mut sum = 0.;
+    for i in 0..n {
+        sum += *x.add(i) * *x.add(i);
+    }
+    sum
+}
diff --git a/enzyme/benchmarks/ReverseMode/lstm/Cargo.lock b/enzyme/benchmarks/ReverseMode/lstm/Cargo.lock
new file mode 100644
index 0000000..270bf43
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/lstm/Cargo.lock
@@ -0,0 +1,7 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "lstm"
+version = "0.1.0"
diff --git a/enzyme/benchmarks/ReverseMode/lstm/Cargo.toml b/enzyme/benchmarks/ReverseMode/lstm/Cargo.toml
new file mode 100644
index 0000000..d28f845
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/lstm/Cargo.toml
@@ -0,0 +1,23 @@
+[package]
+name = "lstm"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+
+[lib]
+crate-type = ["lib"]
+
+[profile.release]
+lto = "fat"
+opt-level = 3
+codegen-units = 1
+# Abort on panic instead of unwinding: the library is linked into a C++ binary.
+panic = "abort"
+strip = true
+#overflow-checks = false
+
+[profile.dev]
+lto = "fat"
diff --git a/enzyme/benchmarks/ReverseMode/lstm/Makefile.make b/enzyme/benchmarks/ReverseMode/lstm/Makefile.make
index 276c5df..1388a54 100644
--- a/enzyme/benchmarks/ReverseMode/lstm/Makefile.make
+++ b/enzyme/benchmarks/ReverseMode/lstm/Makefile.make
@@ -6,6 +6,10 @@
 
 clean:
 	rm -f *.ll *.o results.txt results.json
+	cargo +enzyme clean
+
+$(dir)/benchmarks/ReverseMode/lstm/target/release/liblstm.a: src/lib.rs Cargo.toml
+	RUSTFLAGS="-Z autodiff=Enable,LooseTypes" cargo +enzyme rustc --release --lib --crate-type=staticlib
 
 %-unopt.ll: %.cpp
 	clang++ $(BENCH) $(PTR) $^ -pthread -O2 -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm
@@ -16,8 +20,8 @@
 %-opt.ll: %-raw.ll
 	opt $^ -o $@ -S
 
-lstm.o: lstm-opt.ll
+lstm.o: lstm-opt.ll $(dir)/benchmarks/ReverseMode/lstm/target/release/liblstm.a
 	clang++ -pthread -O2 $^ -o $@ $(BENCHLINK) -lm
 
 results.json: lstm.o
-	./$^
+	numactl -C 1 ./$^
diff --git a/enzyme/benchmarks/ReverseMode/lstm/lstm.cpp b/enzyme/benchmarks/ReverseMode/lstm/lstm.cpp
index dbbc992..ade0b22 100644
--- a/enzyme/benchmarks/ReverseMode/lstm/lstm.cpp
+++ b/enzyme/benchmarks/ReverseMode/lstm/lstm.cpp
@@ -50,15 +50,10 @@
 
 // LSTM OBJECTIVE
 // The LSTM model
-void lstm_model(
-    int hsize,
-    double const* __restrict weight,
-    double const* __restrict bias,
-    double* __restrict hidden,
-    double* __restrict cell,
-    double const* __restrict input
-)
-{
+void lstm_model_restrict(int hsize, double const *__restrict weight,
+                         double const *__restrict bias,
+                         double *__restrict hidden, double *__restrict cell,
+                         double const *__restrict input) {
     // TODO NOTE THIS
     //__builtin_assume(hsize > 0);
 
@@ -94,16 +89,9 @@
 }
 
 // Predict LSTM output given an input
-void lstm_predict(
-    int l,
-    int b,
-    double const* __restrict w,
-    double const* __restrict w2,
-    double* __restrict s,
-    double const* __restrict x,
-    double* __restrict x2
-)
-{
+void lstm_predict_restrict(int l, int b, double const *__restrict w,
+                           double const *__restrict w2, double *__restrict s,
+                           double const *__restrict x, double *__restrict x2) {
     int i;
     for (i = 0; i < b; i++)
     {
@@ -113,7 +101,8 @@
     double* xp = x2;
     for (i = 0; i <= 2 * l * b - 1; i += 2 * b)
     {
-        lstm_model(b, &(w[i * 4]), &(w[(i + b) * 4]), &(s[i]), &(s[i + b]), xp);
+        lstm_model_restrict(b, &(w[i * 4]), &(w[(i + b) * 4]), &(s[i]),
+                            &(s[i + b]), xp);
         xp = &(s[i]);
     }
 
@@ -124,17 +113,12 @@
 }
 
 // LSTM objective (loss function)
-void lstm_objective(
-    int l,
-    int c,
-    int b,
-    double const* __restrict main_params,
-    double const* __restrict extra_params,
-    double* __restrict state,
-    double const* __restrict sequence,
-    double* __restrict loss
-)
-{
+void cxx_restrict_lstm_objective(int l, int c, int b,
+                                 double const *__restrict main_params,
+                                 double const *__restrict extra_params,
+                                 double *__restrict state,
+                                 double const *__restrict sequence,
+                                 double *__restrict loss) {
     int i, t;
     double total = 0.0;
     int count = 0;
@@ -147,7 +131,8 @@
     __builtin_assume(b>0);
     for (t = 0; t <= (c - 1) * b - 1; t += b)
     {
-        lstm_predict(l, b, main_params, extra_params, state, input, ypred);
+        lstm_predict_restrict(l, b, main_params, extra_params, state, input,
+                              ypred);
         lse = logsumexp(ypred, b);
         for (i = 0; i < b; i++)
         {
@@ -177,32 +162,17 @@
 
 // *      tapenade -b -o lstm_tapenade -head "lstm_objective(loss)/(main_params extra_params)" lstm.c
 
-void dlstm_objective(
-    int l,
-    int c,
-    int b,
-    double const* main_params,
-    double* dmain_params,
-    double const* extra_params,
-    double* dextra_params,
-    double* state,
-    double const* sequence,
-    double* loss,
-    double* dloss
-)
-{
-    __enzyme_autodiff(lstm_objective,
-        enzyme_const, l,
-        enzyme_const, c,
-        enzyme_const, b,
-        enzyme_dup, main_params, dmain_params,
-        enzyme_dup, extra_params, dextra_params,
-        enzyme_const, state,
-        enzyme_const, sequence,
-        enzyme_dupnoneed, loss, dloss
-    );
+void dlstm_objective_restrict(int l, int c, int b, double const *main_params,
+                              double *dmain_params, double const *extra_params,
+                              double *dextra_params, double *state,
+                              double const *sequence, double *loss,
+                              double *dloss) {
+    __enzyme_autodiff(cxx_restrict_lstm_objective, enzyme_const, l,
+                      enzyme_const, c, enzyme_const, b, enzyme_dup, main_params,
+                      dmain_params, enzyme_dup, extra_params, dextra_params,
+                      enzyme_const, state, enzyme_const, sequence,
+                      enzyme_dupnoneed, loss, dloss);
 }
-
 }
 
 
@@ -728,3 +698,5 @@
 
 }
 #endif
+
+#include "lstm_mayalias.h"
diff --git a/enzyme/benchmarks/ReverseMode/lstm/lstm_mayalias.h b/enzyme/benchmarks/ReverseMode/lstm/lstm_mayalias.h
new file mode 100644
index 0000000..06401ff
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/lstm/lstm_mayalias.h
@@ -0,0 +1,134 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+// "May-alias" variant of the ADBench LSTM objective used by the Enzyme
+// benchmark: the same computation as the `*_restrict` functions in lstm.cpp,
+// but with no `__restrict` qualifiers, so the compiler and Enzyme must assume
+// that the buffers may alias (presumably to measure the effect of that
+// aliasing information on the generated gradient).
+//
+// Meant to be included at the end of lstm.cpp, which provides the `sigmoid`
+// and `logsumexp` helpers used below.
+
+#include <cmath>
+#include <cstdlib>
+
+extern "C" {
+
+// LSTM cell: updates hidden[0, hsize) and cell[0, hsize) in place.
+// `weight` and `bias` each hold 4 * hsize values for the forget, input, output
+// and change gates (in that order); `input` holds hsize values.
+void lstm_model(int hsize, double const *weight, double const *bias,
+                double *hidden, double *cell, double const *input) {
+  // TODO NOTE THIS
+  //__builtin_assume(hsize > 0);
+
+  double *gates = (double *)malloc(4 * hsize * sizeof(double));
+  double *forget = &(gates[0]);
+  double *ingate = &(gates[hsize]);
+  double *outgate = &(gates[2 * hsize]);
+  double *change = &(gates[3 * hsize]);
+
+  int i;
+  // caching input
+  // hidden (needed)
+  for (i = 0; i < hsize; i++) {
+    forget[i] = sigmoid(input[i] * weight[i] + bias[i]);
+    ingate[i] = sigmoid(hidden[i] * weight[hsize + i] + bias[hsize + i]);
+    outgate[i] =
+        sigmoid(input[i] * weight[2 * hsize + i] + bias[2 * hsize + i]);
+    change[i] = tanh(hidden[i] * weight[3 * hsize + i] + bias[3 * hsize + i]);
+  }
+
+  // caching cell (needed)
+  for (i = 0; i < hsize; i++) {
+    cell[i] = cell[i] * forget[i] + ingate[i] * change[i];
+  }
+
+  for (i = 0; i < hsize; i++) {
+    hidden[i] = outgate[i] * tanh(cell[i]);
+  }
+
+  free(gates);
+}
+
+// Predicts the next b values from input x into x2, through l stacked layers.
+// Layer j keeps its hidden state in s[2*j*b, 2*j*b + b) and its cell state in
+// the following b values; w2 holds the input scale, output scale and output
+// bias (b values each). x2 doubles as the scaled-input buffer for layer 0.
+void lstm_predict(int l, int b, double const *w, double const *w2, double *s,
+                  double const *x, double *x2) {
+  int i;
+  for (i = 0; i < b; i++) {
+    x2[i] = x[i] * w2[i];
+  }
+
+  double *xp = x2;
+  for (i = 0; i <= 2 * l * b - 1; i += 2 * b) {
+    lstm_model(b, &(w[i * 4]), &(w[(i + b) * 4]), &(s[i]), &(s[i + b]), xp);
+    xp = &(s[i]);
+  }
+
+  for (i = 0; i < b; i++) {
+    x2[i] = xp[i] * w2[b + i] + w2[2 * b + i];
+  }
+}
+
+// LSTM objective: runs the model over the c steps of `sequence` (b values each),
+// updating `state` in place, and stores
+// -sum(ygold * (ypred - logsumexp(ypred, b))) / ((c - 1) * b) in *loss.
+void cxx_mayalias_lstm_objective(int l, int c, int b, double const *main_params,
+                                 double const *extra_params, double *state,
+                                 double const *sequence, double *loss) {
+  int i, t;
+  double total = 0.0;
+  int count = 0;
+  const double *input = &(sequence[0]);
+  double *ypred = (double *)malloc(b * sizeof(double));
+  double *ynorm = (double *)malloc(b * sizeof(double));
+  const double *ygold;
+  double lse;
+
+  __builtin_assume(b > 0);
+  for (t = 0; t <= (c - 1) * b - 1; t += b) {
+    lstm_predict(l, b, main_params, extra_params, state, input, ypred);
+    lse = logsumexp(ypred, b);
+    for (i = 0; i < b; i++) {
+      ynorm[i] = ypred[i] - lse;
+    }
+
+    ygold = &(sequence[t + b]);
+    for (i = 0; i < b; i++) {
+      total += ygold[i] * ynorm[i];
+    }
+
+    count += b;
+    input = ygold;
+  }
+
+  *loss = -total / count;
+
+  free(ypred);
+  free(ynorm);
+}
+
+extern int enzyme_const;
+extern int enzyme_dup;
+extern int enzyme_dupnoneed;
+void __enzyme_autodiff(...) noexcept;
+
+// Reverse-mode gradient of cxx_mayalias_lstm_objective through Enzyme:
+// main_params / extra_params are active (gradients go to dmain_params /
+// dextra_params, seeded by *dloss); l, c, b, state and sequence are constant.
+void dlstm_objective_mayalias(int l, int c, int b, double const *main_params,
+                              double *dmain_params, double const *extra_params,
+                              double *dextra_params, double *state,
+                              double const *sequence, double *loss,
+                              double *dloss) {
+  __enzyme_autodiff(cxx_mayalias_lstm_objective, enzyme_const, l, enzyme_const,
+                    c, enzyme_const, b, enzyme_dup, main_params, dmain_params,
+                    enzyme_dup, extra_params, dextra_params, enzyme_const,
+                    state, enzyme_const, sequence, enzyme_dupnoneed, loss,
+                    dloss);
+}
+} // extern "C"
diff --git a/enzyme/benchmarks/ReverseMode/lstm/src/lib.rs b/enzyme/benchmarks/ReverseMode/lstm/src/lib.rs
new file mode 100644
index 0000000..937460f
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/lstm/src/lib.rs
@@ -0,0 +1,118 @@
+//! C ABI entry points for the LSTM benchmark, backed by a raw-pointer port
+//! (`unsf`) and a slice-based port (`safe`) of the ADBench LSTM objective and
+//! of their `#[autodiff]`-generated reverse-mode gradients.
+//!
+//! The slice-based wrappers view the buffers with these lengths (in f64s):
+//! `main_params` / `d_main_params`: `8 * l * b`, `extra_params` / `d_extra_params`:
+//! `3 * b`, `state`: `2 * l * b`, `sequence`: `c * b`; `loss`/`res`/`d_res`: one.
+#![feature(autodiff)]
+
+pub(crate) mod safe;
+pub(crate) mod unsf;
+use std::slice;
+
+/// Widens the C `int` sizes to `usize` (values are assumed non-negative).
+fn dims(l: i32, c: i32, b: i32) -> (usize, usize, usize) {
+    (l as usize, c as usize, b as usize)
+}
+
+/// Primal LSTM objective, raw-pointer port; the loss is written to `*loss`.
+#[no_mangle]
+pub extern "C" fn rust_unsafe_lstm_objective(
+    l: i32,
+    c: i32,
+    b: i32,
+    main_params: *const f64,
+    extra_params: *const f64,
+    state: *mut f64,
+    sequence: *const f64,
+    loss: *mut f64,
+) {
+    let (l, c, b) = dims(l, c, b);
+    // SAFETY: the pointers are forwarded unchanged; validity is the caller's duty.
+    unsafe {
+        unsf::lstm_unsafe_objective(l, c, b, main_params, extra_params, state, sequence, loss);
+    }
+}
+
+/// Primal LSTM objective, slice-based port; the loss is written to `*loss`.
+#[no_mangle]
+pub extern "C" fn rust_safe_lstm_objective(
+    l: i32,
+    c: i32,
+    b: i32,
+    main_params: *const f64,
+    extra_params: *const f64,
+    state: *mut f64,
+    sequence: *const f64,
+    loss: *mut f64,
+) {
+    let (l, c, b) = dims(l, c, b);
+    // SAFETY: buffer lengths follow the contract in the crate documentation.
+    unsafe {
+        let main_params = slice::from_raw_parts(main_params, 2 * l * 4 * b);
+        let extra_params = slice::from_raw_parts(extra_params, 3 * b);
+        let state = slice::from_raw_parts_mut(state, 2 * l * b);
+        let sequence = slice::from_raw_parts(sequence, c * b);
+        safe::lstm_objective(l, c, b, main_params, extra_params, state, sequence, &mut *loss);
+    }
+}
+
+/// Reverse-mode gradient, raw-pointer port: accumulates into `d_main_params`
+/// and `d_extra_params`, seeded by `*d_res`.
+#[no_mangle]
+pub extern "C" fn rust_unsafe_dlstm_objective(
+    l: i32,
+    c: i32,
+    b: i32,
+    main_params: *const f64,
+    d_main_params: *mut f64,
+    extra_params: *const f64,
+    d_extra_params: *mut f64,
+    state: *mut f64,
+    sequence: *const f64,
+    res: *mut f64,
+    d_res: *mut f64,
+) {
+    let (l, c, b) = dims(l, c, b);
+    // SAFETY: the pointers are forwarded unchanged; validity is the caller's duty.
+    unsafe {
+        unsf::d_lstm_unsafe_objective(
+            l, c, b, main_params, d_main_params, extra_params, d_extra_params, state,
+            sequence, res, d_res,
+        );
+    }
+}
+
+/// Reverse-mode gradient, slice-based port: accumulates into `d_main_params`
+/// and `d_extra_params`, seeded by `*d_res`.
+#[no_mangle]
+pub extern "C" fn rust_safe_dlstm_objective(
+    l: i32,
+    c: i32,
+    b: i32,
+    main_params: *const f64,
+    d_main_params: *mut f64,
+    extra_params: *const f64,
+    d_extra_params: *mut f64,
+    state: *mut f64,
+    sequence: *const f64,
+    res: *mut f64,
+    d_res: *mut f64,
+) {
+    let (l, c, b) = dims(l, c, b);
+    // SAFETY: buffer lengths follow the contract in the crate documentation,
+    // and each shadow buffer is distinct from its primal.
+    unsafe {
+        let main_params = slice::from_raw_parts(main_params, 2 * l * 4 * b);
+        let d_main_params = slice::from_raw_parts_mut(d_main_params, 2 * l * 4 * b);
+        let extra_params = slice::from_raw_parts(extra_params, 3 * b);
+        let d_extra_params = slice::from_raw_parts_mut(d_extra_params, 3 * b);
+        let state = slice::from_raw_parts_mut(state, 2 * l * b);
+        let sequence = slice::from_raw_parts(sequence, c * b);
+        safe::d_lstm_objective(
+            l, c, b, main_params, d_main_params, extra_params, d_extra_params, state,
+            sequence, &mut *res, &mut *d_res,
+        );
+    }
+}
diff --git a/enzyme/benchmarks/ReverseMode/lstm/src/safe.rs b/enzyme/benchmarks/ReverseMode/lstm/src/safe.rs
new file mode 100644
index 0000000..d6847a4
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/lstm/src/safe.rs
@@ -0,0 +1,260 @@
+//! Slice-based Rust port of the ADBench LSTM objective; `#[autodiff]`
+//! generates the reverse-mode `d_lstm_objective` from `lstm_objective`.
+use std::autodiff::autodiff;
+use std::slice;
+
+/// Logistic sigmoid `1 / (1 + e^-x)`.
+fn sigmoid(x: f64) -> f64 {
+    1.0 / (1.0 + (-x).exp())
+}
+
+/// `ln(sum(exp(vect)) + 2)`; the `+ 2` mirrors the `sum += 2` of the C++ `logsumexp`.
+#[inline]
+fn logsumexp(vect: &[f64]) -> f64 {
+    let mut sum = 0.0;
+    for &val in vect {
+        sum += val.exp();
+    }
+    sum += 2.0;
+    sum.ln()
+}
+
+/// One LSTM cell step, updating `hidden[..hsize]` and `cell[..hsize]` in place.
+///
+/// `weight` and `bias` hold `4 * hsize` values for the forget, input, output and
+/// change gates (in that order); `input` must hold at least `hsize` values.
+fn lstm_model(
+    hsize: usize,
+    weight: &[f64],
+    bias: &[f64],
+    hidden: &mut [f64],
+    cell: &mut [f64],
+    input: &[f64],
+) {
+    let mut gates = vec![0.0; 4 * hsize];
+    let (a, b) = gates.split_at_mut(2 * hsize);
+    let ((forget, ingate), (outgate, change)) = (a.split_at_mut(hsize), b.split_at_mut(hsize));
+
+    // Gate activations (reads the old hidden state).
+    for i in 0..hsize {
+        forget[i] = sigmoid(input[i] * weight[i] + bias[i]);
+        ingate[i] = sigmoid(hidden[i] * weight[hsize + i] + bias[hsize + i]);
+        outgate[i] = sigmoid(input[i] * weight[2 * hsize + i] + bias[2 * hsize + i]);
+        change[i] = (hidden[i] * weight[3 * hsize + i] + bias[3 * hsize + i]).tanh();
+    }
+
+    // New cell state.
+    for i in 0..hsize {
+        cell[i] = cell[i] * forget[i] + ingate[i] * change[i];
+    }
+
+    // New hidden state.
+    for i in 0..hsize {
+        hidden[i] = outgate[i] * cell[i].tanh();
+    }
+}
+
+/// Predicts the next `b` values from `x` into `x2` through `l` stacked layers.
+///
+/// Layer `j` keeps its hidden state in `s[2 * j * b..][..b]` and its cell state in
+/// the following `b` values; `w` holds `8 * b` weights/biases per layer and `w2`
+/// holds the input scale, output scale and output bias (`b` values each).
+fn lstm_predict(
+    l: usize,
+    b: usize,
+    w: &[f64],
+    w2: &[f64],
+    s: &mut [f64],
+    x: &[f64],
+    x2: &mut [f64],
+) {
+    for i in 0..b {
+        x2[i] = x[i] * w2[i];
+    }
+
+    // `i < 2 * l * b` (rather than `i <= 2 * l * b - 1`) avoids a `usize`
+    // underflow, and an infinite loop when `b == 0`, for empty models.
+    let mut i = 0;
+    while i < 2 * l * b {
+        // Split into non-overlapping borrows: the layer input (the scaled input
+        // for layer 0, else the previous layer's hidden state), and this
+        // layer's hidden and cell state.
+        let (xp, s1, s2) = if i == 0 {
+            let (s1, s2) = s.split_at_mut(b);
+            (&mut x2[..], s1, s2)
+        } else {
+            let tmp = &mut s[i - 2 * b..];
+            let (a, d) = tmp.split_at_mut(2 * b);
+            let (d, c) = d.split_at_mut(b);
+            (a, d, c)
+        };
+
+        lstm_model(
+            b,
+            &w[i * 4..(i + b) * 4],
+            &w[(i + b) * 4..(i + 2 * b) * 4],
+            s1,
+            s2,
+            xp,
+        );
+
+        i += 2 * b;
+    }
+
+    // Output projection from the last layer's hidden state, or directly from
+    // the scaled input when there are no layers (as in the C++ reference).
+    if i == 0 {
+        for j in 0..b {
+            x2[j] = x2[j] * w2[b + j] + w2[2 * b + j];
+        }
+    } else {
+        let xp = &s[i - 2 * b..];
+        for j in 0..b {
+            x2[j] = xp[j] * w2[b + j] + w2[2 * b + j];
+        }
+    }
+}
+
+/// LSTM objective over `c` steps of `sequence` (`b` values each): predicts each
+/// next step from the current one, updating `state` in place, and writes
+/// `-sum(ygold * (ypred - logsumexp(ypred))) / ((c - 1) * b)` to `*loss`.
+#[autodiff(
+    d_lstm_objective,
+    Reverse,
+    Const,         // l
+    Const,         // c
+    Const,         // b
+    Duplicated,    // main_params
+    Duplicated,    // extra_params
+    Const,         // state
+    Const,         // sequence
+    DuplicatedOnly // loss
+)]
+pub(crate) fn lstm_objective(
+    l: usize,
+    c: usize,
+    b: usize,
+    main_params: &[f64],
+    extra_params: &[f64],
+    state: &mut [f64],
+    sequence: &[f64],
+    loss: &mut f64,
+) {
+    let mut total = 0.0;
+
+    // Only the first `b` values of `input` are read by `lstm_predict`.
+    let mut input = sequence;
+    let mut ypred = vec![0.0; b];
+    let mut ynorm = vec![0.0; b];
+
+    // `saturating_sub` keeps `c == 0` from underflowing (the C++ loop simply
+    // does not run in that case).
+    let steps = c.saturating_sub(1);
+    for j in 0..steps {
+        let t = j * b;
+        lstm_predict(l, b, main_params, extra_params, state, input, &mut ypred);
+        let lse = logsumexp(&ypred);
+        for i in 0..b {
+            ynorm[i] = ypred[i] - lse;
+        }
+
+        let ygold = &sequence[t + b..];
+        for i in 0..b {
+            total += ygold[i] * ynorm[i];
+        }
+
+        input = ygold;
+    }
+    let count = steps * b;
+
+    *loss = -total / count as f64;
+}
+
+/// C entry point for [`lstm_objective`] taking `usize` sizes.
+///
+/// NOTE(review): unlike the `i32` wrappers in lib.rs, the sizes here are
+/// `usize`, so a C caller must declare them as `size_t`, not `int`.
+#[no_mangle]
+pub extern "C" fn rust_lstm_objective(
+    l: usize,
+    c: usize,
+    b: usize,
+    main_params: *const f64,
+    extra_params: *const f64,
+    state: *mut f64,
+    sequence: *const f64,
+    loss: *mut f64,
+) {
+    // SAFETY: the caller must pass buffers of at least these lengths.
+    let (main_params, extra_params, state, sequence) = unsafe {
+        (
+            slice::from_raw_parts(main_params, 2 * l * 4 * b),
+            slice::from_raw_parts(extra_params, 3 * b),
+            slice::from_raw_parts_mut(state, 2 * l * b),
+            slice::from_raw_parts(sequence, c * b),
+        )
+    };
+
+    // SAFETY: `loss` must point to a valid, writable f64.
+    unsafe {
+        lstm_objective(
+            l,
+            c,
+            b,
+            main_params,
+            extra_params,
+            state,
+            sequence,
+            &mut *loss,
+        );
+    }
+}
+
+/// C entry point for `d_lstm_objective` taking `usize` sizes: accumulates
+/// gradients into `d_main_params` / `d_extra_params`, seeded by `*d_res`.
+///
+/// NOTE(review): the sizes are `usize`, so a C caller must declare them as
+/// `size_t`, not `int`.
+#[no_mangle]
+pub extern "C" fn rust_dlstm_objective(
+    l: usize,
+    c: usize,
+    b: usize,
+    main_params: *const f64,
+    d_main_params: *mut f64,
+    extra_params: *const f64,
+    d_extra_params: *mut f64,
+    state: *mut f64,
+    sequence: *const f64,
+    res: *mut f64,
+    d_res: *mut f64,
+) {
+    // SAFETY: the caller must pass non-aliasing buffers of at least these lengths.
+    let (main_params, d_main_params, extra_params, d_extra_params, state, sequence) = unsafe {
+        (
+            slice::from_raw_parts(main_params, 2 * l * 4 * b),
+            slice::from_raw_parts_mut(d_main_params, 2 * l * 4 * b),
+            slice::from_raw_parts(extra_params, 3 * b),
+            slice::from_raw_parts_mut(d_extra_params, 3 * b),
+            slice::from_raw_parts_mut(state, 2 * l * b),
+            slice::from_raw_parts(sequence, c * b),
+        )
+    };
+
+    // SAFETY: `res` and `d_res` must point to valid, writable f64s.
+    unsafe {
+        d_lstm_objective(
+            l,
+            c,
+            b,
+            main_params,
+            d_main_params,
+            extra_params,
+            d_extra_params,
+            state,
+            sequence,
+            &mut *res,
+            &mut *d_res,
+        );
+    }
+}
diff --git a/enzyme/benchmarks/ReverseMode/lstm/src/unsf.rs b/enzyme/benchmarks/ReverseMode/lstm/src/unsf.rs
new file mode 100644
index 0000000..498bf96
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/lstm/src/unsf.rs
@@ -0,0 +1,165 @@
+use std::autodiff::autodiff;
+
+/// Logistic sigmoid of a scalar: 1 / (1 + e^(-x)).
+fn sigmoid(x: f64) -> f64 {
+    1.0 / (1.0 + (-x).exp())
+}
+
+/// Returns `ln(2 + sum(exp(vect[i])))` over the first `sz` entries.
+///
+/// Note the extra `+ 2.0`: this is not a plain log-sum-exp. Presumably it
+/// matches the other implementations of this benchmark, so do not change it
+/// in isolation -- TODO confirm.
+///
+/// # Safety
+/// `vect` must be valid for reads of `sz` consecutive `f64` values.
+unsafe fn logsumexp(vect: *const f64, sz: usize) -> f64 {
+    let mut sum: f64 = 0.0;
+    for i in 0..sz {
+        sum += (*vect.add(i)).exp();
+    }
+    sum += 2.0;
+    sum.ln()
+}
+
+/// Applies one LSTM cell of width `hsize`, updating `hidden` and `cell` in place.
+///
+/// `weight` and `bias` each hold four consecutive blocks of `hsize` values,
+/// used in order for the forget, input, output and change gates.
+///
+/// # Safety
+/// `weight` and `bias` must be valid for reads of `4 * hsize` values, `hidden`
+/// and `cell` for reads and writes of `hsize` values, and `input` for reads of
+/// `hsize` values.
+unsafe fn lstm_model(
+    hsize: usize,
+    weight: *const f64,
+    bias: *const f64,
+    hidden: *mut f64,
+    cell: *mut f64,
+    input: *const f64,
+) {
+    let mut gates = vec![0.0; 4 * hsize];
+    // Derive all four gate pointers from a single base pointer: re-indexing
+    // `gates` (e.g. `gates[hsize..].as_mut_ptr()`) creates a fresh `&mut`
+    // borrow of the whole buffer, which can invalidate pointers taken earlier.
+    let base: *mut f64 = gates.as_mut_ptr();
+    let forget = base;
+    let ingate = base.add(hsize);
+    let outgate = base.add(2 * hsize);
+    let change = base.add(3 * hsize);
+
+    // Gate activations.
+    for i in 0..hsize {
+        *forget.add(i) = sigmoid(*input.add(i) * *weight.add(i) + *bias.add(i));
+        *ingate.add(i) = sigmoid(*hidden.add(i) * *weight.add(hsize + i) + *bias.add(hsize + i));
+        *outgate.add(i) = sigmoid(*input.add(i) * *weight.add(2 * hsize + i) + *bias.add(2 * hsize + i));
+        *change.add(i) = (*hidden.add(i) * *weight.add(3 * hsize + i) + *bias.add(3 * hsize + i)).tanh();
+    }
+
+    // Cell state update.
+    for i in 0..hsize {
+        *cell.add(i) = *cell.add(i) * *forget.add(i) + *ingate.add(i) * *change.add(i);
+    }
+
+    // Hidden state update.
+    for i in 0..hsize {
+        *hidden.add(i) = *outgate.add(i) * (*cell.add(i)).tanh();
+    }
+}
+
+/// Runs the `l`-layer LSTM stack on one input step `x` and writes the
+/// prediction into `x2`.
+///
+/// Layer k reads its weights from `w[8bk .. 8bk + 4b]`, its biases from
+/// `w[8bk + 4b .. 8bk + 8b]` and keeps its (hidden, cell) state in
+/// `s[2bk .. 2bk + 2b]`. `w2` holds three blocks of `b` values: input scale,
+/// output scale and output bias.
+///
+/// # Safety
+/// `w` must be valid for reads of `8 * l * b` values, `w2` for `3 * b`, `s` for
+/// reads and writes of `2 * l * b`, `x` for reads of `b`, and `x2` for reads and
+/// writes of `b`.
+unsafe fn lstm_predict(
+    l: usize,
+    b: usize,
+    w: *const f64,
+    w2: *const f64,
+    s: *mut f64,
+    x: *const f64,
+    x2: *mut f64,
+) {
+    for i in 0..b {
+        *x2.add(i) = *x.add(i) * *w2.add(i);
+    }
+
+    let mut xp = x2;
+    // Half-open range: `0..=stop - 1` would underflow when `stop == 0`.
+    let stop = 2 * l * b;
+    for i in (0..stop).step_by(2 * b) {
+        lstm_model(b, w.add(i * 4), w.add((i + b) * 4), s.add(i), s.add(i + b), xp);
+        xp = s.add(i);
+    }
+
+    for i in 0..b {
+        *x2.add(i) = *xp.add(i) * *w2.add(b + i) + *w2.add(2 * b + i);
+    }
+}
+
+// LSTM loss over `sequence` (`c` steps of width `b`), written to `*loss`.
+//
+// Each step `t` predicts step `t + 1` from step `t`; the loss is the negated
+// mean, over every predicted entry, of `ygold * (ypred - logsumexp(ypred))`.
+// If `c <= 1` nothing is predicted and the loss is NaN (0 / 0).
+// `d_lstm_unsafe_objective` is the reverse-mode derivative generated by
+// `#[autodiff]`.
+//
+// Safety: `main_params` must be valid for reads of `8 * l * b` values,
+// `extra_params` for `3 * b`, `state` for reads and writes of `2 * l * b`,
+// `sequence` for reads of `c * b`, and `loss` for one write.
+//
+// Panics:
+// If `b == 0`.
+#[autodiff(d_lstm_unsafe_objective, Reverse, Const, Const, Const, Duplicated, Duplicated, Const, Const, DuplicatedOnly)]
+pub(crate) unsafe fn lstm_unsafe_objective(
+    l: usize,
+    c: usize,
+    b: usize,
+    main_params: *const f64,
+    extra_params: *const f64,
+    state: *mut f64,
+    sequence: *const f64,
+    loss: *mut f64,
+) {
+    let mut total = 0.0;
+    let mut count = 0;
+
+    let mut input = sequence;
+    let mut ypred = vec![0.0; b];
+    let mut ynorm = vec![0.0; b];
+
+    assert!(b > 0);
+
+    // `saturating_sub` keeps `c == 0` from underflowing, and the half-open
+    // range avoids the `stop - 1` underflow when `c <= 1`.
+    let stop = c.saturating_sub(1) * b;
+    for t in (0..stop).step_by(b) {
+        lstm_predict(l, b, main_params, extra_params, state, input, ypred.as_mut_ptr());
+        let lse = logsumexp(ypred.as_ptr(), b);
+        for i in 0..b {
+            ynorm[i] = ypred[i] - lse;
+        }
+
+        // The target for this step is the next step of the sequence.
+        let ygold = sequence.add(t + b);
+        for i in 0..b {
+            total += *ygold.add(i) * ynorm[i];
+        }
+
+        count += b;
+        input = ygold;
+    }
+
+    *loss = -total / count as f64;
+}
diff --git a/enzyme/benchmarks/ReverseMode/ode-real/Cargo.lock b/enzyme/benchmarks/ReverseMode/ode-real/Cargo.lock
new file mode 100644
index 0000000..93dcf6a
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ode-real/Cargo.lock
@@ -0,0 +1,7 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "ode"
+version = "0.1.0"
diff --git a/enzyme/benchmarks/ReverseMode/ode-real/Cargo.toml b/enzyme/benchmarks/ReverseMode/ode-real/Cargo.toml
new file mode 100644
index 0000000..b7386a4
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ode-real/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "ode"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+
+[lib]
+crate-type = ["lib"]
+
+[profile.release]
+lto = "fat"
+opt-level = 3
+codegen-units = 1
+panic = "abort" # `unwind = "abort"` is not a profile key; Cargo only warned and ignored it
+strip = true
+#overflow-checks = false
+
+[profile.dev]
+lto = "fat"
diff --git a/enzyme/benchmarks/ReverseMode/ode-real/Makefile.make b/enzyme/benchmarks/ReverseMode/ode-real/Makefile.make
index 5abb283..87af95f 100644
--- a/enzyme/benchmarks/ReverseMode/ode-real/Makefile.make
+++ b/enzyme/benchmarks/ReverseMode/ode-real/Makefile.make
@@ -6,6 +6,10 @@
 
 clean:
 	rm -f *.ll *.o results.txt results.json
+	cargo +enzyme clean
+
+$(dir)/benchmarks/ReverseMode/ode-real/target/release/libode.a: src/lib.rs src/safe.rs src/unsf.rs Cargo.toml # all crate sources, so any edit triggers a rebuild
+	RUSTFLAGS="-Z autodiff=Enable,LooseTypes" cargo +enzyme rustc --release --lib --crate-type=staticlib
 
 %-unopt.ll: %.cpp
 	clang++ $(BENCH) $(PTR) $^ -O2 -fno-use-cxa-atexit -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm
@@ -16,17 +20,17 @@
 %-opt.ll: %-raw.ll
 	opt $^ -o $@ -S
 
-ode.o: ode-opt.ll
+ode.o: ode-opt.ll $(dir)/benchmarks/ReverseMode/ode-real/target/release/libode.a
 	clang++ $(BENCH) -O2 $^ -o $@ $(BENCHLINK)
 
 results.json: ode.o
-	./$^ 1000 | tee $@
-	./$^ 1000 >> $@
-	./$^ 1000 >> $@
-	./$^ 1000 >> $@
-	./$^ 1000 >> $@
-	./$^ 1000 >> $@
-	./$^ 1000 >> $@
-	./$^ 1000 >> $@
-	./$^ 1000 >> $@
-	./$^ 1000 >> $@
+	numactl -C 1 ./$^ 1000 | tee $@
+	numactl -C 1 ./$^ 1000 >> $@
+	numactl -C 1 ./$^ 1000 >> $@
+	numactl -C 1 ./$^ 1000 >> $@
+	numactl -C 1 ./$^ 1000 >> $@
+	numactl -C 1 ./$^ 1000 >> $@
+	numactl -C 1 ./$^ 1000 >> $@
+	numactl -C 1 ./$^ 1000 >> $@
+	numactl -C 1 ./$^ 1000 >> $@
+	numactl -C 1 ./$^ 1000 >> $@
diff --git a/enzyme/benchmarks/ReverseMode/ode-real/ode.cpp b/enzyme/benchmarks/ReverseMode/ode-real/ode.cpp
index 7c7113d..17007c8 100644
--- a/enzyme/benchmarks/ReverseMode/ode-real/ode.cpp
+++ b/enzyme/benchmarks/ReverseMode/ode-real/ode.cpp
@@ -24,20 +24,8 @@
   return (end->tv_sec-start->tv_sec) + 1e-6*(end->tv_usec-start->tv_usec);
 }
 
-#define BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS
-#define BOOST_NO_EXCEPTIONS
 #include <iostream>
-#include <boost/array.hpp>
-
-#include <boost/numeric/odeint.hpp>
-
-#include <boost/throw_exception.hpp>
-void boost::throw_exception(std::exception const & e){
-    //do nothing
-}
-
 using namespace std;
-using namespace boost::numeric::odeint;
 
 #define N 32
 #define xmin 0.
@@ -76,7 +64,7 @@
 }
 
 __attribute__((noinline))
-void brusselator_2d_loop(double* __restrict du, double* __restrict dv, const double* __restrict u, const double* __restrict v, const double* __restrict p, double t) {
+void brusselator_2d_loop_restrict(double* __restrict du, double* __restrict dv, const double* __restrict u, const double* __restrict v, const double* __restrict p, double t) {
   double A = p[0];
   double B = p[1];
   double alpha = p[2];
@@ -107,33 +95,131 @@
   }
 }
 
-typedef boost::array< double , 2 * N * N > state_type;
+__attribute__((noinline))
+void brusselator_2d_loop_norestrict(double* du, double* dv, const double* u, const double* v, const double* p, double t) { // presumably the same stencil as brusselator_2d_loop_restrict, minus __restrict, so the compiler must assume the buffers may alias
+  double A = p[0];
+  double B = p[1];
+  double alpha = p[2];
+  double dx = (double)1/(N-1);
 
-void lorenz( const state_type &x , state_type &dxdt , double t )
+  alpha = alpha/(dx*dx);
+
+  for(int i=0; i<N; i++) {
+    for(int j=0; j<N; j++) {
+
+      double x = RANGE(xmin, xmax, i, N);
+      double y = RANGE(ymin, ymax, j, N);
+
+      unsigned ip1 = (i == N-1) ? i : (i+1); // neighbour indices clamp at the grid edges
+      unsigned im1 = (i == 0) ? i : (i-1);
+
+      unsigned jp1 = (j == N-1) ? j : (j+1);
+      unsigned jm1 = (j == 0) ? j : (j-1);
+
+      double u2v = GET(u, i, j) * GET(u, i, j) * GET(v, i, j);
+
+      GETnb(du, i, j) = alpha*( GET(u, im1, j) + GET(u, ip1, j) + GET(u, i, jp1) + GET(u, i, jm1) - 4 * GET(u, i, j))
+                      + B + u2v - (A + 1)*GET(u, i, j) + brusselator_f(x, y, t);
+
+      GETnb(dv, i, j) = alpha*( GET(v, im1, j) + GET(v, ip1, j) + GET(v, i, jp1) + GET(v, i, jm1) - 4 * GET(v, i, j))
+                      + A * GET(u, i, j) - u2v;
+    }
+  }
+}
+
+typedef double state_type[2*N*N]; // u (first N*N values) followed by v (next N*N values)
+
+void lorenz_norestrict( const state_type &x, state_type &dxdt, double t )
 {
     // Extract the parameters
   double p[3] = { /*A*/ 3.4, /*B*/ 1, /*alpha*/10. };
-  brusselator_2d_loop(dxdt.c_array(), dxdt.c_array() + N * N, x.data(), x.data() + N * N, p, t);
+  brusselator_2d_loop_norestrict(dxdt, dxdt + N * N, x, x + N * N, p, t);
 }
 
-// init_brusselator(x.c_array(), x.c_array() + N*N)
+void lorenz_restrict( const state_type &x, state_type &dxdt, double t )
+{
+    // Fixed Brusselator parameters A, B, alpha
+  double p[3] = { /*A*/ 3.4, /*B*/ 1, /*alpha*/10. };
+  brusselator_2d_loop_restrict(dxdt, dxdt + N * N, x, x + N * N, p, t);
+}
 
-double foobar(const double* p, const state_type x, const state_type adjoint, double t) {
+extern "C" void rust_lorenz_safe(const double* x, double* dxdt, double t); // this and the next three are defined in src/lib.rs (linked via libode.a)
+extern "C" void rust_dbrusselator_2d_loop_safe(double* adjoint, const double* x, double* dx, const double* p, double* dp, double t);
+extern "C" void rust_lorenz_unsf(const double* x, double* dxdt, double t);
+extern "C" void rust_dbrusselator_2d_loop_unsf(double* adjoint, const double* x, double* dx, const double* p, double* dp, double t);
+
+double rustfoobar_unsf(const double *p, const state_type x, const state_type adjoint, double t) {
+  double dp[3] = { 0. };
+
+  state_type dx = { 0. };
+
+  state_type dadjoint_inp; // scratch copy of the WHOLE adjoint: u and v halves (2*N*N values)
+  for (int i = 0; i < 2 * N * N; i++) {
+    dadjoint_inp[i] = adjoint[i];
+  }
+
+  rust_dbrusselator_2d_loop_unsf(dadjoint_inp, x, dx, p, dp, t);
+  return dx[0];
+}
+
+double rustfoobar_safe(const double *p, const state_type x, const state_type adjoint, double t) {
+  double dp[3] = { 0. };
+
+  state_type dx = { 0. };
+
+  state_type dadjoint_inp; // scratch copy of the WHOLE adjoint: u and v halves (2*N*N values)
+  for (int i = 0; i < 2 * N * N; i++) {
+    dadjoint_inp[i] = adjoint[i];
+  }
+
+  rust_dbrusselator_2d_loop_safe(dadjoint_inp, x, dx, p, dp, t);
+  return dx[0];
+}
+
+double foobar_restrict(const double* p, const state_type x, const state_type adjoint, double t) {
     double dp[3] = { 0. };
 
     state_type dx = { 0. };
 
-    state_type dadjoint_inp = adjoint;
+    state_type dadjoint_inp; // scratch copy of the WHOLE adjoint: u and v halves (2*N*N values)
+    for (int i = 0; i < 2 * N * N; i++) {
+      dadjoint_inp[i] = adjoint[i];
+    }
 
     state_type dxdu;
 
-    __enzyme_autodiff<void>(brusselator_2d_loop,
-//                            enzyme_dup, dxdu.c_array(), dadjoint_inp.c_array(),
-//                            enzyme_dup, dxdu.c_array() + N * N, dadjoint_inp.c_array() + N * N,
-                            enzyme_dupnoneed, nullptr, dadjoint_inp.data(),
-                            enzyme_dupnoneed, nullptr, dadjoint_inp.data() + N * N,
-                            enzyme_dup, x.data(), dx.data(),
-                            enzyme_dup, x.data() + N * N, dx.data() + N * N,
+    __enzyme_autodiff<void>(brusselator_2d_loop_restrict,
+                            enzyme_dup, dxdu, dadjoint_inp,
+                            enzyme_dup, dxdu + N * N, dadjoint_inp + N * N,
+ //                           enzyme_dupnoneed, nullptr, dadjoint_inp,
+ //                           enzyme_dupnoneed, nullptr, dadjoint_inp + N * N,
+                            enzyme_dup, x, dx,
+                            enzyme_dup, x + N * N, dx + N * N,
+                            enzyme_dup, p, dp,
+                            enzyme_const, t);
+
+    return dx[0];
+}
+
+double foobar_norestrict(const double* p, const state_type x, const state_type adjoint, double t) {
+    double dp[3] = { 0. };
+
+    state_type dx = { 0. };
+
+    state_type dadjoint_inp; // scratch copy of the WHOLE adjoint: u and v halves (2*N*N values)
+    for (int i = 0; i < 2 * N * N; i++) {
+      dadjoint_inp[i] = adjoint[i];
+    }
+
+    state_type dxdu;
+
+    __enzyme_autodiff<void>(brusselator_2d_loop_norestrict,
+                            enzyme_dup, dxdu, dadjoint_inp,
+                            enzyme_dup, dxdu + N * N, dadjoint_inp + N * N,
+ //                           enzyme_dupnoneed, nullptr, dadjoint_inp,
+ //                           enzyme_dupnoneed, nullptr, dadjoint_inp + N * N,
+                            enzyme_dup, x, dx,
+                            enzyme_dup, x + N * N, dx + N * N,
                             enzyme_dup, p, dp,
                             enzyme_const, t);
 
@@ -486,14 +572,17 @@
 
     state_type dx = { 0. };
 
-    state_type dadjoint_inp = adjoint;
+    state_type dadjoint_inp; // scratch copy of the WHOLE adjoint: u and v halves (2*N*N values)
+    for (int i = 0; i < 2 * N * N; i++) {
+      dadjoint_inp[i] = adjoint[i];
+    }
 
     state_type dxdu;
 
-    brusselator_2d_loop_b(nullptr, dadjoint_inp.data(),
-                          nullptr, dadjoint_inp.data() + N * N,
-                          x.data(), dx.data(),
-                          x.data() + N * N, dx.data() + N * N,
+    brusselator_2d_loop_b(nullptr, dadjoint_inp,
+                          nullptr, dadjoint_inp + N * N,
+                          x, dx,
+                          x + N * N, dx + N * N,
                           p, dp,
                           t);
 
@@ -505,10 +594,10 @@
   const double p[3] = { /*A*/ 3.4, /*B*/ 1, /*alpha*/10. };
 
   state_type x;
-  init_brusselator(x.data(), x.data() + N * N);
+  init_brusselator(x, x + N * N);
 
   state_type adjoint;
-  init_brusselator(adjoint.data(), adjoint.data() + N * N);
+  init_brusselator(adjoint, adjoint + N * N);
 
   double t = 2.1;
 
@@ -542,174 +631,97 @@
 
   double res;
   for(int i=0; i<10000; i++)
-  res = foobar(p, x, adjoint, t);
+  res = foobar_norestrict(p, x, adjoint, t);
 
   gettimeofday(&end, NULL);
-  printf("Enzyme combined %0.6f res=%f\n", tdiff(&start, &end), res);
+  printf("C++  Enzyme combined mayalias %0.6f res=%f\n", tdiff(&start, &end), res);
   }
+  
+  {
+  struct timeval start, end;
+  gettimeofday(&start, NULL);
+
+  double res;
+  for(int i=0; i<10000; i++)
+  res = foobar_restrict(p, x, adjoint, t);
+
+  gettimeofday(&end, NULL);
+  printf("C++  Enzyme combined restrict %0.6f res=%f\n", tdiff(&start, &end), res);
+  }
+  
+  {
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+
+    double res;
+    for(int i=0; i<10000; i++)
+    res = rustfoobar_safe(p, x, adjoint, t);
+
+    gettimeofday(&end, NULL);
+    printf("Rust Enzyme combined safe %0.6f res=%f\n", tdiff(&start, &end), res);
+  }
+
+  {
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+
+    double res;
+    for(int i=0; i<10000; i++)
+    res = rustfoobar_unsf(p, x, adjoint, t);
+
+    gettimeofday(&end, NULL);
+    printf("Rust Enzyme combined unsf %0.6f res=%f\n", tdiff(&start, &end), res);
+  }
+
+  {
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+    state_type x2;
+
+    for(int i=0; i<10000; i++) {
+      lorenz_norestrict(x, x2, t);
+    }
+
+    gettimeofday(&end, NULL);
+    printf("C++  fwd mayalias %0.6f res=%f\n", tdiff(&start, &end), x2[0]);
+  }
+
+  {
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+    state_type x2;
+
+    for(int i=0; i<10000; i++) {
+      lorenz_restrict(x, x2, t);
+    }
+
+    gettimeofday(&end, NULL);
+    printf("C++  fwd restrict %0.6f res=%f\n", tdiff(&start, &end), x2[0]);
+  }
+
+  {
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+    state_type x2;
+
+    for(int i=0; i<10000; i++)
+    rust_lorenz_safe(x, x2, t);
+
+    gettimeofday(&end, NULL);
+    printf("Rust fwd safe %0.6f res=%f\n\n", tdiff(&start, &end), x2[0]);
+  }
+
+  {
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+    state_type x2;
+
+    for(int i=0; i<10000; i++)
+    rust_lorenz_unsf(x, x2, t);
+
+    gettimeofday(&end, NULL);
+    printf("Rust fwd unsf %0.6f res=%f\n\n", tdiff(&start, &end), x2[0]);
+  }
+
   //printf("res=%f\n", foobar(1000));
 }
-
-
-#if 0
-
-typedef boost::array< double , 6 > state_type;
-
-void lorenz( const state_type &x , state_type &dxdt , double t )
-{
-    // Extract the parameters
-    double k1 = x[3];
-    double k2 = x[4];
-    double k3 = x[5];
-
-    dxdt[0] = -k1 * x[0] + k3 * x[1] * x[2];
-    dxdt[1] = k1 * x[0] - k2 * x[1] * x[1] - k3 * x[1] * x[2];
-    dxdt[2] = k2 * x[1] * x[1];
-
-    // Don't change the parameters p
-    dxdt[3] = 0;
-    dxdt[4] = 0;
-    dxdt[5] = 0;
-}
-
-double foobar(double* p, uint64_t iters) {
-    state_type x = { 1.0, 0, 0, p[0], p[1], p[2] }; // initial conditions
-    double t = 1e5;
-    typedef controlled_runge_kutta< runge_kutta_dopri5< state_type , typename state_type::value_type , state_type , double > > stepper_type;
-    //typedef euler< state_type , typename state_type::value_type , state_type , double > stepper_type;
-    integrate_const( stepper_type(), lorenz , x , 0.0 , t, t/iters );
-
-    return x[0];
-}
-
-typedef boost::array< adouble , 6 > astate_type;
-
-void alorenz( const astate_type &x , astate_type &dxdt , adouble t )
-{
-    // Extract the parameters
-    adouble k1 = x[3];
-    adouble k2 = x[4];
-    adouble k3 = x[5];
-
-    dxdt[0] = -k1 * x[0] + k3 * x[1] * x[2];
-    dxdt[1] = k1 * x[0] - k2 * x[1] * x[1] - k3 * x[1] * x[2];
-    dxdt[2] = k2 * x[1] * x[1];
-
-    // Don't change the parameters p
-    dxdt[3] = 0;
-    dxdt[4] = 0;
-    dxdt[5] = 0;
-}
-
-adouble afoobar(adouble* p, uint64_t iters) {
-    astate_type x = { 1.0, 0, 0, p[0], p[1], p[2] }; // initial conditions
-    double t = 1e5;
-    typedef controlled_runge_kutta< runge_kutta_dopri5< astate_type , typename astate_type::value_type , astate_type , adouble > > stepper_type;
-    //typedef euler< astate_type , typename astate_type::value_type , astate_type , adouble > stepper_type;
-    integrate_const( stepper_type(), alorenz , x , 0.0 , t, t/iters );
-
-    return x[0];
-}
-
-static
-double afoobar_and_gradient(double* p_in, double* dp_out, uint64_t iters) {
-    adept::Stack stack;
-    adouble x[3] = { p_in[0], p_in[1], p_in[2] };
-    stack.new_recording();
-    adouble y = afoobar(x, iters);
-    y.set_gradient(1.0);
-    stack.compute_adjoint();
-    for(int i=0; i<3; i++)
-      dp_out[i] = x[i].get_gradient();
-    return y.value();
-}
-
-static void adept_sincos(uint64_t iters) {
-  {
-  struct timeval start, end;
-  gettimeofday(&start, NULL);
-
-  double p[3] = { 0.04,3e7,1e4 };
-  double res = foobar(p, iters);
-
-  gettimeofday(&end, NULL);
-  printf("Adept real %0.6f res=%f\n", tdiff(&start, &end), res);
-  }
-
-  {
-  struct timeval start, end;
-  gettimeofday(&start, NULL);
-
-  adept::Stack stack;
-  adouble p[3] = { 0.04,3e7,1e4 };
- // stack.new_recording();
-  adouble resa = afoobar(p, iters);
-  double res = resa.value();
-
-  gettimeofday(&end, NULL);
-  printf("Adept forward %0.6f res=%f\n", tdiff(&start, &end), res);
-  }
-
-  {
-  struct timeval start, end;
-  gettimeofday(&start, NULL);
-
-  double p[3] = { 0.04,3e7,1e4 };
-  double dp[3] = { 0 };
-  afoobar_and_gradient(p, dp, iters);
-
-  gettimeofday(&end, NULL);
-  printf("Adept combined %0.6f res'=%f\n", tdiff(&start, &end), dp[0]);
-  }
-}
-
-static void enzyme_sincos(double inp, uint64_t iters) {
-
-  {
-  struct timeval start, end;
-  gettimeofday(&start, NULL);
-
-  double p[3] = { 0.04,3e7,1e4 };
-  double res = foobar(p, iters);
-
-  gettimeofday(&end, NULL);
-  printf("Enzyme real %0.6f res=%f\n", tdiff(&start, &end), res);
-  }
-
-  {
-  struct timeval start, end;
-  gettimeofday(&start, NULL);
-
-  double p[3] = { 0.04,3e7,1e4 };
-  double res = foobar(p, iters);
-
-  gettimeofday(&end, NULL);
-  printf("Enzyme forward %0.6f res=%f\n", tdiff(&start, &end), res);
-  }
-
-  {
-  struct timeval start, end;
-  gettimeofday(&start, NULL);
-
-  double p[3] = { 0.04,3e7,1e4 };
-  double dp[3] = { 0 };
-  __enzyme_autodiff<void>(foobar, p, dp, iters);
-
-  gettimeofday(&end, NULL);
-  printf("Enzyme combined %0.6f res'=%f\n", tdiff(&start, &end), dp[0]);
-  }
-}
-
-int main(int argc, char** argv) {
-
-  int max_iters = atoi(argv[1]) ;
-  double inp = 2.1;
-
-  //for(int iters=max_iters/20; iters<=max_iters; iters+=max_iters/20) {
-  auto iters = max_iters;
-    printf("iters=%d\n", iters);
-    adept_sincos(inp, iters);
-    enzyme_sincos(inp, iters);
-  //}
-}
-#endif
diff --git a/enzyme/benchmarks/ReverseMode/ode-real/src/lib.rs b/enzyme/benchmarks/ReverseMode/ode-real/src/lib.rs
new file mode 100644
index 0000000..4fbc7e7
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ode-real/src/lib.rs
@@ -0,0 +1,102 @@
+//! C-ABI entry points for the Brusselator ODE benchmark, declared `extern "C"`
+//! in `ode.cpp`.
+//!
+//! Every kernel exists twice: `safe` (references and fixed-size arrays) and
+//! `unsf` (raw pointers), so the driver can time the two styles side by side.
+#![feature(autodiff)]
+#![feature(slice_as_chunks)]
+#![feature(iter_next_chunk)]
+#![feature(array_ptr_get)]
+#![allow(non_snake_case)]
+#![allow(non_camel_case_types)]
+#![allow(non_upper_case_globals)]
+
+pub mod safe;
+pub mod unsf;
+
+/// Full ODE state: `N * N` values of `u` followed by `N * N` values of `v`.
+type StateType = [f64; 2 * N * N];
+
+/// Grid points per dimension; must agree with `N` in `ode.cpp`, `safe.rs` and `unsf.rs`.
+const N: usize = 32;
+
+/// Evaluates the right-hand side `dxdt = f(x, t)` with the raw-pointer kernel.
+///
+/// The caller must pass non-null, aligned pointers to whole `StateType` arrays.
+#[no_mangle]
+pub extern "C" fn rust_lorenz_unsf(x: *const StateType, dxdt: *mut StateType, t: f64) {
+    // SAFETY: pointer validity is the caller's responsibility (see above).
+    unsafe { unsf::lorenz(x, dxdt, t) };
+}
+
+/// Evaluates the right-hand side `dxdt = f(x, t)` with the safe kernel.
+///
+/// The caller must pass non-null, aligned, non-overlapping pointers to whole
+/// `StateType` arrays.
+#[no_mangle]
+pub extern "C" fn rust_lorenz_safe(x: *const StateType, dxdt: *mut StateType, t: f64) {
+    // SAFETY: pointer validity and non-aliasing are the caller's responsibility.
+    let (x, dxdt) = unsafe { (&*x, &mut *dxdt) };
+    safe::lorenz(x, dxdt, t);
+}
+
+/// Reverse-mode gradient of the Brusselator kernel, raw-pointer variant.
+///
+/// `adjoint` holds the output adjoint for the whole state (the `du` half then
+/// the `dv` half); gradients flow into `dx` and `dp`, which the C++ driver
+/// zero-initialises. The primal outputs are not needed and go to scratch
+/// buffers. The caller must pass non-null, aligned pointers to whole arrays.
+#[no_mangle]
+pub extern "C" fn rust_dbrusselator_2d_loop_unsf(adjoint: *mut StateType, x: *const StateType, dx: *mut StateType, p: *const [f64;3], dp: *mut [f64;3], t: f64) {
+    let mut scratch_du = [0.; N * N];
+    let mut scratch_dv = [0.; N * N];
+
+    let dx1: *mut f64 = dx.cast::<f64>();
+    let dadj1: *mut f64 = adjoint.cast::<f64>();
+    let x1: *const f64 = x.cast::<f64>();
+    // SAFETY: each state array holds 2*N*N values, so its second (v) half
+    // starts N*N elements in and stays in bounds.
+    let (dx2, dadj2, x2) = unsafe { (dx1.add(N * N), dadj1.add(N * N), x1.add(N * N)) };
+
+    // SAFETY: every pointer covers N*N (state halves) or 3 (parameters) values.
+    unsafe {
+        unsf::dbrusselator_2d_loop_unsf(
+            scratch_du.as_mut_ptr(), dadj1,
+            scratch_dv.as_mut_ptr(), dadj2,
+            x1, dx1,
+            x2, dx2,
+            p.cast::<f64>(), dp.cast::<f64>(),
+            t,
+        )
+    };
+}
+
+/// Reverse-mode gradient of the Brusselator kernel, safe variant; same
+/// contract as `rust_dbrusselator_2d_loop_unsf`, plus the pointers must not
+/// overlap (they become Rust references).
+#[no_mangle]
+pub extern "C" fn rust_dbrusselator_2d_loop_safe(adjoint: *mut StateType, x: *const StateType, dx: *mut StateType, p: *const [f64;3], dp: *mut [f64;3], t: f64) {
+    // SAFETY: pointer validity and non-aliasing are the caller's responsibility.
+    let (x, dx, adjoint, p, dp) = unsafe { (&*x, &mut *dx, &mut *adjoint, &*p, &mut *dp) };
+
+    // ode.cpp always passes these values. Checked in debug builds only, so the
+    // timed release path has no extra branches that the unsf variant lacks.
+    debug_assert!(p[0] == 3.4);
+    debug_assert!(p[1] == 1.);
+    debug_assert!(p[2] == 10.);
+    debug_assert!(t == 2.1);
+
+    // Split each state into its u and v halves as fixed-size arrays.
+    let ([dx1, dx2], []): (&mut [[f64; N*N]], &mut [f64]) = dx.as_chunks_mut() else { unreachable!() };
+    let ([dadj1, dadj2], []): (&mut [[f64; N*N]], &mut [f64]) = adjoint.as_chunks_mut() else { unreachable!() };
+    let ([x1, x2], []): (&[[f64; N*N]], &[f64]) = x.as_chunks() else { unreachable!() };
+
+    // Primal outputs are not needed; write them into scratch buffers.
+    let mut scratch_du = [0.; N * N];
+    let mut scratch_dv = [0.; N * N];
+    safe::dbrusselator_2d_loop(&mut scratch_du, dadj1,
+                               &mut scratch_dv, dadj2,
+                               x1, dx1,
+                               x2, dx2,
+                               p, dp, t);
+}
diff --git a/enzyme/benchmarks/ReverseMode/ode-real/src/safe.rs b/enzyme/benchmarks/ReverseMode/ode-real/src/safe.rs
new file mode 100644
index 0000000..ddf3685
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ode-real/src/safe.rs
@@ -0,0 +1,93 @@
+use std::autodiff::autodiff;
+
+/// Grid points per dimension.
+const N: usize = 32;
+// Domain bounds of the square grid.
+const xmin: f64 = 0.;
+const xmax: f64 = 1.;
+const ymin: f64 = 0.;
+const ymax: f64 = 1.;
+
+/// Coordinate of grid index `i` when `[min, max]` is sampled at `N_var` points.
+#[inline(always)]
+fn range(min: f64, max: f64, i: usize, N_var: usize) -> f64 {
+    (max - min) / (N_var as f64 - 1.) * i as f64 + min
+}
+
+/// Forcing term: 5.0 inside the disc of radius 0.1 around (0.3, 0.6) once
+/// `t >= 1.1`, otherwise 0.0.
+fn brusselator_f(x: f64, y: f64, t: f64) -> f64 {
+    let eq1 = (x - 0.3) * (x - 0.3) + (y - 0.6) * (y - 0.6) <= 0.1 * 0.1;
+    let eq2 = t >= 1.1;
+    if eq1 && eq2 {
+        5.0
+    } else {
+        0.0
+    }
+}
+
+/// Writes the initial condition into `u` and `v` (each `N * N` values, index
+/// `N * i + j`). Not called from this crate.
+#[expect(unused)]
+fn init_brusselator(u: &mut [f64], v: &mut [f64]) {
+    assert!(u.len() == N * N);
+    assert!(v.len() == N * N);
+    for i in 0..N {
+        for j in 0..N {
+            let x = range(xmin, xmax, i, N);
+            let y = range(ymin, ymax, j, N);
+            u[N * i + j] = 22.0 * (y * (1.0 - y)) * (y * (1.0 - y)).sqrt();
+            v[N * i + j] = 27.0 * (x * (1.0 - x)) * (x * (1.0 - x)).sqrt();
+        }
+    }
+}
+
+// Brusselator right-hand side on the `N x N` grid: writes du/dt into `d_u`
+// and dv/dt into `d_v` for states `u`, `v`, parameters `p = [A, B, alpha]`
+// and time `t`. Neighbour indices are clamped at the grid edges.
+//
+// `#[autodiff]` generates `dbrusselator_2d_loop`, its reverse-mode
+// derivative with every argument except `t` duplicated.
+#[no_mangle]
+#[autodiff(dbrusselator_2d_loop, Reverse, Duplicated, Duplicated, Duplicated, Duplicated, Duplicated, Const)]
+pub fn brusselator_2d_loop(d_u: &mut [f64;N*N], d_v: &mut [f64;N*N], u: &[f64;N*N], v: &[f64;N*N], p: &[f64;3], t: f64) {
+    let A = p[0];
+    let B = p[1];
+    let alpha = p[2];
+    let dx = 1. / (N - 1) as f64;
+    let alpha = alpha / (dx * dx);
+    for i in 0..N {
+        for j in 0..N {
+            let x = range(xmin, xmax, i, N);
+            let y = range(ymin, ymax, j, N);
+            let ip1 = if i == N - 1 { i } else { i + 1 };
+            let im1 = if i == 0 { i } else { i - 1 };
+            let jp1 = if j == N - 1 { j } else { j + 1 };
+            let jm1 = if j == 0 { j } else { j - 1 };
+            let u2v = u[N * i + j] * u[N * i + j] * v[N * i + j];
+            d_u[N * i + j] = alpha * (u[N * im1 + j] + u[N * ip1 + j] + u[N * i + jp1] + u[N * i + jm1] - 4. * u[N * i + j])
+                + B + u2v - (A + 1.) * u[N * i + j] + brusselator_f(x, y, t);
+            d_v[N * i + j] = alpha * (v[N * im1 + j] + v[N * ip1 + j] + v[N * i + jp1] + v[N * i + jm1] - 4. * v[N * i + j])
+                + A * u[N * i + j] - u2v;
+        }
+    }
+}
+
+/// Full ODE state: `u` (first `N * N` values) followed by `v`.
+pub type StateType = [f64; 2 * N * N];
+
+/// Evaluates the Brusselator right-hand side for the whole state `x` at time
+/// `t` with the fixed parameters `[3.4, 1., 10.]`, writing into `dxdt`.
+pub fn lorenz(x: &StateType, dxdt: &mut StateType, t: f64) {
+    let p = [3.4, 1., 10.];
+    // Borrow the halves in place: converting them into owned arrays would
+    // compute into temporaries and leave `dxdt` unwritten.
+    let (du, dv) = dxdt.split_at_mut(N * N);
+    let du: &mut [f64; N * N] = du.try_into().unwrap();
+    let dv: &mut [f64; N * N] = dv.try_into().unwrap();
+    let (u, v) = x.split_at(N * N);
+    let u: &[f64; N * N] = u.try_into().unwrap();
+    let v: &[f64; N * N] = v.try_into().unwrap();
+    brusselator_2d_loop(du, dv, u, v, &p, t);
+}
+
diff --git a/enzyme/benchmarks/ReverseMode/ode-real/src/unsf.rs b/enzyme/benchmarks/ReverseMode/ode-real/src/unsf.rs
new file mode 100644
index 0000000..9f1e400
--- /dev/null
+++ b/enzyme/benchmarks/ReverseMode/ode-real/src/unsf.rs
@@ -0,0 +1,98 @@
+use std::autodiff::autodiff;
+
+/// Grid points per dimension.
+const N: usize = 32;
+// Domain bounds of the square grid.
+const xmin: f64 = 0.;
+const xmax: f64 = 1.;
+const ymin: f64 = 0.;
+const ymax: f64 = 1.;
+
+/// Coordinate of grid index `i` when `[min, max]` is sampled at `N_var` points.
+#[inline(always)]
+fn range(min: f64, max: f64, i: usize, N_var: usize) -> f64 {
+    (max - min) / (N_var as f64 - 1.) * i as f64 + min
+}
+
+/// Forcing term: 5.0 inside the disc of radius 0.1 around (0.3, 0.6) once
+/// `t >= 1.1`, otherwise 0.0.
+fn brusselator_f(x: f64, y: f64, t: f64) -> f64 {
+    let eq1 = (x - 0.3) * (x - 0.3) + (y - 0.6) * (y - 0.6) <= 0.1 * 0.1;
+    let eq2 = t >= 1.1;
+    if eq1 && eq2 {
+        5.0
+    } else {
+        0.0
+    }
+}
+
+/// Writes the initial condition into `u` and `v` (index `N * i + j`).
+/// Not called from this crate.
+///
+/// # Safety
+/// `u` and `v` must each be valid for writes of `N * N` values.
+#[expect(unused)]
+unsafe fn init_brusselator(u: *mut f64, v: *mut f64) {
+    for i in 0..N {
+        for j in 0..N {
+            let x = range(xmin, xmax, i, N);
+            let y = range(ymin, ymax, j, N);
+            *u.add(N * i + j) = 22.0 * (y * (1.0 - y)) * (y * (1.0 - y)).sqrt();
+            *v.add(N * i + j) = 27.0 * (x * (1.0 - x)) * (x * (1.0 - x)).sqrt();
+        }
+    }
+}
+
+// Raw-pointer Brusselator right-hand side on the `N x N` grid: writes du/dt to
+// `d_u[0..N*N]` and dv/dt to `d_v[0..N*N]` for states `u`, `v`, parameters
+// `p = [A, B, alpha]` and time `t`. Neighbour indices are clamped at the edges.
+//
+// Safety: `u` and `v` must be valid for reads of `N * N` values, `d_u` and
+// `d_v` for writes of `N * N`, and `p` for reads of 3.
+//
+// `#[autodiff]` generates `dbrusselator_2d_loop_unsf`, its reverse-mode
+// derivative with every argument except `t` duplicated.
+#[no_mangle]
+#[autodiff(dbrusselator_2d_loop_unsf, Reverse, Duplicated, Duplicated, Duplicated, Duplicated, Duplicated, Const)]
+pub unsafe fn brusselator_2d_loop_unsf(d_u: *mut f64, d_v: *mut f64, u: *const f64, v: *const f64, p: *const f64, t: f64) {
+    let A = *p.add(0);
+    let B = *p.add(1);
+    let alpha = *p.add(2);
+    let dx = 1. / (N - 1) as f64;
+    let alpha = alpha / (dx * dx);
+    for i in 0..N {
+        for j in 0..N {
+            let x = range(xmin, xmax, i, N);
+            let y = range(ymin, ymax, j, N);
+            let ip1 = if i == N - 1 { i } else { i + 1 };
+            let im1 = if i == 0 { i } else { i - 1 };
+            let jp1 = if j == N - 1 { j } else { j + 1 };
+            let jm1 = if j == 0 { j } else { j - 1 };
+            let u2v = *u.add(N * i + j) * *u.add(N * i + j) * *v.add(N * i + j);
+            *d_u.add(N * i + j) = alpha * (*u.add(N * im1 + j) + *u.add(N * ip1 + j) + *u.add(N * i + jp1) + *u.add(N * i + jm1) - 4. * *u.add(N * i + j))
+                + B + u2v - (A + 1.) * *u.add(N * i + j) + brusselator_f(x, y, t);
+            *d_v.add(N * i + j) = alpha * (*v.add(N * im1 + j) + *v.add(N * ip1 + j) + *v.add(N * i + jp1) + *v.add(N * i + jm1) - 4. * *v.add(N * i + j))
+                + A * *u.add(N * i + j) - u2v;
+        }
+    }
+}
+
+/// Full ODE state: `u` (first `N * N` values) followed by `v`.
+type StateType = [f64; 2 * N * N];
+
+/// Raw-pointer counterpart of `crate::safe::lorenz`: evaluates the Brusselator
+/// right-hand side for the whole state `x` at time `t` with the fixed
+/// parameters `[3.4, 1., 10.]`, writing into `dxdt`.
+///
+/// # Safety
+/// `x` must be valid for reads and `dxdt` for writes of a whole `StateType`.
+pub unsafe fn lorenz(x: *const StateType, dxdt: *mut StateType, t: f64) {
+    let p = [3.4, 1., 10.];
+    let du: *mut f64 = dxdt.cast::<f64>();
+    let u: *const f64 = x.cast::<f64>();
+    // Each state holds 2*N*N values; the v halves start N*N elements in.
+    let dv = du.add(N * N);
+    let v = u.add(N * N);
+    brusselator_2d_loop_unsf(du, dv, u, v, p.as_ptr(), t);
+}
+