// This should work on LLVM 7, 8, and 9; however, in CI the clang installed on Ubuntu 18.04 cannot load
// a clang plugin without segfaulting on exit. This is fine on Ubuntu 20.04 and on later LLVM versions.
| // RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli -; fi |
| // RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli -; fi |
| // RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli -; fi |
| // RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli -; fi |
| // RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli -; fi |
| // RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli -; fi |
| // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi |
| // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi |
| // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi |
| // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi |
| // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi |
| // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi |
| |
| #include "../blasinfra.h" |
| |
// Activity markers for __enzyme_autodiff. In the calls in this file each one
// precedes the argument(s) it annotates: enzyme_const marks an inactive
// argument, enzyme_dup marks a primal pointer followed by its shadow (e.g.
// `enzyme_dup, A, dA`). NOTE(review): presumably the Enzyme plugin recognizes
// these globals by name rather than by value — confirm against Enzyme docs.
int enzyme_dup;
int enzyme_out;
int enzyme_const;
// Declared but never defined here; presumably every call is replaced by the
// Enzyme clang plugin (loaded via %loadClangEnzyme in the RUN lines) with a
// call to the generated gradient of the function passed as the first argument.
template<typename ...T>
void __enzyme_autodiff(void*, T...);
| |
| void my_dgemv(char layout, char trans, int M, int N, double alpha, double* __restrict__ A, int lda, double* __restrict__ X, int incx, double beta, double* __restrict__ Y, int incy) { |
| cblas_dgemv(layout, trans, M, N, alpha, A, lda, X, incx, beta, Y, incy); |
| inDerivative = true; |
| } |
| |
| void ow_dgemv(char layout, char trans, int M, int N, double alpha, double* A, int lda, double* X, int incx, double beta, double* Y, int incy) { |
| cblas_dgemv(layout, trans, M, N, alpha, A, lda, X, incx, beta, Y, incy); |
| inDerivative = true; |
| } |
| |
| |
| void my_dsymv(char layout, char uplo, int N, double alpha, double* __restrict__ A, int lda, double* __restrict__ X, int incx, double beta, double* __restrict__ Y, int incy) { |
| cblas_dsymv(layout, uplo, N, alpha, A, lda, X, incx, beta, Y, incy); |
| inDerivative = true; |
| } |
| |
| void ow_dsymv(char layout, char uplo, int N, double alpha, double* A, int lda, double* X, int incx, double beta, double* Y, int incy) { |
| cblas_dsymv(layout, uplo, N, alpha, A, lda, X, incx, beta, Y, incy); |
| inDerivative = true; |
| } |
| |
| |
| double my_ddot(int N, double* __restrict__ X, int incx, double* __restrict__ Y, int incy) { |
| double res = cblas_ddot(N, X, incx, Y, incy); |
| inDerivative = true; |
| return res; |
| } |
| |
| double my_dnrm2(int N, double *__restrict__ X, int incx) { |
| double res = cblas_dnrm2(N, X, incx); |
| inDerivative = true; |
| return res; |
| } |
| |
| void my_dgemm(char layout, char transA, char transB, int M, int N, int K, double alpha, double* __restrict__ A, int lda, double* __restrict__ B, int ldb, double beta, double* __restrict__ C, int ldc) { |
| cblas_dgemm(layout, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); |
| inDerivative = true; |
| } |
| |
| void ow_dgemm(char layout, char transA, char transB, int M, int N, int K, double alpha, double* A, int lda, double* B, int ldb, double beta, double* C, int ldc) { |
| cblas_dgemm(layout, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); |
| inDerivative = true; |
| } |
| |
| void my_dtrmv(char layout, char uplo, char trans, |
| char diag, int N, double * __restrict__ A, int lda, |
| double *__restrict__ X, int incx) { |
| cblas_dtrmv(layout, uplo, trans, diag, N, A, lda, X, incx); |
| inDerivative = true; |
| } |
| |
| void my_dtrmm(char layout, char side, char uplo, |
| char trans, char diag, int M, int N, |
| double alpha, double * __restrict__ A, int lda, |
| double *__restrict B, int ldb) { |
| cblas_dtrmm(layout, side, uplo, trans, diag, M, N, alpha, A, lda, B, ldb); |
| inDerivative = true; |
| } |
| |
| void ow_dtrmm(char layout, char side, char uplo, |
| char trans, char diag, int M, int N, |
| double alpha, double * A, int lda, |
| double * B, int ldb) { |
| cblas_dtrmm(layout, side, uplo, trans, diag, M, N, alpha, A, lda, B, ldb); |
| inDerivative = true; |
| } |
| |
| void my_dsyrk(char layout, char uplo, char trans, |
| int N, int K, double alpha, |
| double *__restrict__ A, int lda, double beta, |
| double *__restrict__ C, int ldc) { |
| |
| cblas_dsyrk(layout, uplo, trans, N, K, alpha, A, lda, beta, |
| C, ldc); |
| inDerivative = true; |
| } |
| |
| void my_potrf(char layout, char uplo, int N, double *__restrict__ A, int lda) { |
| int info; |
| cblas_dpotrf(layout, uplo, N, A, lda, &info); |
| inDerivative = true; |
| } |
| void ow_potrf(char layout, char uplo, int N, double *__restrict__ A, int lda) { |
| int info; |
| cblas_dpotrf(layout, uplo, N, A, lda, &info); |
| cblas_dscal(1, 0.0, A, lda); |
| inDerivative = true; |
| } |
| |
| void my_potrs(char layout, char uplo, int N, int Nrhs, double *__restrict__ A, int lda, double *__restrict__ B, int ldb) { |
| int info; |
| cblas_dpotrs(layout, uplo, N, Nrhs, A, lda, B, ldb, &info); |
| inDerivative = true; |
| } |
| |
| void my_trtrs(char layout, char uplo, char trans, char diag, int N, int Nrhs, |
| double *__restrict__ A, int lda, double *__restrict__ B, |
| int ldb) { |
| int info; |
| cblas_dtrtrs(layout, uplo, trans, diag, N, Nrhs, A, lda, B, ldb, &info); |
| inDerivative = true; |
| } |
| void ow_trtrs(char layout, char uplo, char trans, char diag, int N, int Nrhs, |
| double *A, int lda, double *B, int ldb) { |
| int info; |
| cblas_dtrtrs(layout, uplo, trans, diag, N, Nrhs, A, lda, B, ldb, &info); |
| cblas_dscal(1, 0.0, A, lda); |
| inDerivative = true; |
| } |
| |
| void my_symm(char layout, char side, char uplo, |
| int M, int N, double alpha, |
| double * __restrict__ A, int lda, double * __restrict__ B, |
| int ldb, double beta, double * __restrict__ C, |
| int ldc) { |
| cblas_dsymm(layout, side, uplo, M, N, alpha, A, lda, B, ldb, beta, C, ldc); |
| inDerivative = true; |
| } |
| |
| void ow_symm(char layout, char side, char uplo, |
| int M, int N, double alpha, |
| double * A, int lda, double * B, |
| int ldb, double beta, double * C, |
| int ldc) { |
| cblas_dsymm(layout, side, uplo, M, N, alpha, A, lda, B, ldb, beta, C, ldc); |
| inDerivative = true; |
| } |
| |
| static void dotTests() { |
| |
| std::string Test = "DOT active both "; |
| BlasInfo inputs[6] = { |
| /*A*/ BlasInfo(A, N, incA), |
| /*B*/ BlasInfo(B, N, incB), |
| /*C*/ BlasInfo(C, M, incC), |
| BlasInfo(), |
| BlasInfo(), |
| BlasInfo(), |
| }; |
| init(); |
| my_ddot(N, A, incA, B, incB); |
| |
| // Check memory of primal on own. |
| checkMemoryTrace(inputs, "Primal " + Test, calls); |
| |
| init(); |
| __enzyme_autodiff((void*) my_ddot, |
| enzyme_const, N, |
| enzyme_dup, A, dA, |
| enzyme_const, incA, |
| enzyme_dup, B, dB, |
| enzyme_const, incB); |
| foundCalls = calls; |
| init(); |
| |
| my_ddot(N, A, incA, B, incB); |
| |
| inDerivative = true; |
| |
| cblas_daxpy(N, 1.0, B, incB, dA, incA); |
| cblas_daxpy(N, 1.0, A, incA, dB, incB); |
| |
| checkTest(Test); |
| |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| } |
| |
| static void nrm2Tests() { |
| |
| std::string Test = "DNRM2 active both "; |
| BlasInfo inputs[6] = { |
| /*A*/ BlasInfo(A, N, incA), |
| /*B*/ BlasInfo(B, N, incB), |
| /*C*/ BlasInfo(C, M, incC), BlasInfo(), BlasInfo(), BlasInfo(), |
| }; |
| init(); |
| my_dnrm2(N, A, incA); |
| |
| // Check memory of primal on own. |
| checkMemoryTrace(inputs, "Primal " + Test, calls); |
| |
| init(); |
| __enzyme_autodiff((void *)my_dnrm2, enzyme_const, N, enzyme_dup, A, dA, |
| enzyme_const, incA); |
| foundCalls = calls; |
| init(); |
| |
| my_dnrm2(N, A, incA); |
| |
| inDerivative = true; |
| |
| double tmp = cblas_dnrm2(N, A, incA); |
| cblas_daxpy(N, 1.0 / tmp, A, incA, dA, incA); |
| |
| checkTest(Test); |
| |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| } |
| |
// GEMV tests. For each layout and each transA setting:
//   1. run my_dgemv once and check the single recorded call's arguments;
//   2. "GEMV active A, C ":    gradient w.r.t. A and C (B constant);
//   3. "GEMV active A, B, C ": gradient w.r.t. A, B and C;
//   4. "GEMV active/overwrite": same activity through ow_dgemv (no
//      __restrict__), for which Enzyme is expected to cache A and B.
// Each derivative case compares the calls Enzyme recorded (foundCalls) with a
// hand-written expected reverse pass (calls) via checkTest, then checks the
// memory accesses of both traces with checkMemoryTrace.
static void gemvTests() {
  // N means normal matrix, T means transposed
  for (char layout : { CblasRowMajor, CblasColMajor }) {
    for (auto transA : {CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans}) {
      // todo in fortran blas consider 'N', 'n', 'T', 't'}

      {

        bool trans = !is_normal(transA);
        std::string Test = "GEMV active A, C ";
        // Buffers tracked by checkMemoryTrace. X (B) has N entries and Y (C)
        // has M entries when A is not transposed; the lengths swap otherwise.
        // Slots 4 and 5 are filled with the A/B caches in the overwrite case.
        BlasInfo inputs[6] = {
          /*A*/ BlasInfo(A, layout, M, N, lda),
          /*B*/ BlasInfo(B, trans ? M : N, incB),
          /*C*/ BlasInfo(C, trans ? N : M, incC),
          BlasInfo(),
          BlasInfo(),
          BlasInfo()
        };
        init();
        my_dgemv(layout, (char)transA, M, N, alpha, A, lda, B, incB, beta, C, incC);

        // The primal must be exactly one GEMV, recorded before inDerivative is
        // raised, with every argument forwarded unchanged.
        assert(calls.size() == 1);
        assert(calls[0].inDerivative == false);
        assert(calls[0].type == CallType::GEMV);
        assert(calls[0].pout_arg1 == C);
        assert(calls[0].pin_arg1 == A);
        assert(calls[0].pin_arg2 == B);
        assert(calls[0].farg1 == alpha);
        assert(calls[0].farg2 == beta);
        assert(calls[0].layout == layout);
        assert(calls[0].targ1 == (char)transA);
        assert(calls[0].targ2 == UNUSED_TRANS);
        assert(calls[0].iarg1 == M);
        assert(calls[0].iarg2 == N);
        assert(calls[0].iarg3 == UNUSED_INT);
        assert(calls[0].iarg4 == lda);
        assert(calls[0].iarg5 == incB);
        assert(calls[0].iarg6 == incC);

        // Check memory of primal on own.
        checkMemoryTrace(inputs, "Primal " + Test, calls);

        // Enzyme's gradient w.r.t. A and C; keep the calls it recorded.
        init();
        __enzyme_autodiff((void*) my_dgemv,
                          enzyme_const, layout,
                          enzyme_const, transA,
                          enzyme_const, M,
                          enzyme_const, N,
                          enzyme_const, alpha,
                          enzyme_dup, A, dA,
                          enzyme_const, lda,
                          enzyme_const, B,
                          enzyme_const, incB,
                          enzyme_const, beta,
                          enzyme_dup, C, dC,
                          enzyme_const, incC);
        foundCalls = calls;
        init();

        // Expected trace: primal, then the reverse pass.
        my_dgemv(layout, (char)transA, M, N, alpha, A, lda, B, incB, beta, C, incC);

        inDerivative = true;
        // dA += alpha * dC * B^T (not transposed) or alpha * B * dC^T
        // (transposed), issued as a rank-1 update on the M x N dA.
        cblas_dger(layout, M, N, alpha, trans ? B : dC, trans ? incB : incC, trans ? dC : B, trans ? incC : incB, dA, lda);
        // dC *= beta (gradient through the beta * C term).
        cblas_dscal(trans ? N : M, beta, dC, incC);

        checkTest(Test);

        // Check memory of primal of expected derivative
        checkMemoryTrace(inputs, "Expected " + Test, calls);

        // Check memory of primal of our derivative (if equal above, it
        // should be the same).
        checkMemoryTrace(inputs, "Found " + Test, foundCalls);

        Test = "GEMV active A, B, C ";

        // Enzyme's gradient w.r.t. A, B and C.
        init();
        __enzyme_autodiff((void*) my_dgemv,
                          enzyme_const, layout,
                          enzyme_const, transA,
                          enzyme_const, M,
                          enzyme_const, N,
                          enzyme_const, alpha,
                          enzyme_dup, A, dA,
                          enzyme_const, lda,
                          enzyme_dup, B, dB,
                          enzyme_const, incB,
                          enzyme_const, beta,
                          enzyme_dup, C, dC,
                          enzyme_const, incC);
        foundCalls = calls;
        init();

        my_dgemv(layout, (char)transA, M, N, alpha, A, lda, B, incB, beta, C, incC);

        inDerivative = true;
        // dA += alpha * dC * B^T (not transposed) or alpha * B * dC^T
        // (transposed), issued as a rank-1 update on the M x N dA.
        cblas_dger(layout, M, N, alpha, trans ? B : dC, trans ? incB : incC, trans ? dC : B, trans ? incC : incB, dA, lda);

        // dB += alpha * op(A)^T * dC
        cblas_dgemv(layout, (char)transpose(transA), M, N, alpha, A, lda, dC, incC, 1.0, dB, incB);

        // dC *= beta (gradient through the beta * C term).
        cblas_dscal(trans ? N : M, beta, dC, incC);

        checkTest(Test);

        // Check memory of primal of expected derivative
        checkMemoryTrace(inputs, "Expected " + Test, calls);

        // Check memory of primal of our derivative (if equal above, it
        // should be the same).
        checkMemoryTrace(inputs, "Found " + Test, foundCalls);



        Test = "GEMV active/overwrite";

        // Enzyme's gradient of the non-__restrict__ variant.
        init();
        __enzyme_autodiff((void*) ow_dgemv,
                          enzyme_const, layout,
                          enzyme_const, transA,
                          enzyme_const, M,
                          enzyme_const, N,
                          enzyme_const, alpha,
                          enzyme_dup, A, dA,
                          enzyme_const, lda,
                          enzyme_dup, B, dB,
                          enzyme_const, incB,
                          enzyme_const, beta,
                          enzyme_dup, C, dC,
                          enzyme_const, incC);
        foundCalls = calls;
        init();

        // The first two calls Enzyme recorded are expected to write its
        // caches of A (M x N, leading dimension M) and B (contiguous). Replay
        // those copies into the same buffers so the expected trace matches,
        // and track the caches' memory too. NOTE(review): uplo '\0' is
        // neither 'U' nor 'L' — presumably "copy the whole matrix".
        assert(foundCalls.size() > 2);
        auto A_cache = (double*)foundCalls[0].pout_arg1;
        cblas_dlacpy(layout, '\0', M, N, A, lda, A_cache, M);
        inputs[4] = BlasInfo(A_cache, layout, M, N, M);
        auto B_cache = (double*)foundCalls[1].pout_arg1;
        cblas_dcopy(trans ? M : N, B, incB, B_cache, 1);
        inputs[5] = BlasInfo(B_cache, trans ? M : N, 1);

        ow_dgemv(layout, (char)transA, M, N, alpha, A, lda, B, incB, beta, C, incC);

        inDerivative = true;
        // dA += alpha * dC * B^T (not transposed) or alpha * B * dC^T
        // (transposed), reading B from its unit-stride cache.
        cblas_dger(layout, M, N, alpha,
                   trans ? B_cache : dC,
                   trans ? 1 : incC,
                   trans ? dC : B_cache,
                   trans ? incC : 1, dA,
                   lda);

        // dB += alpha * op(A)^T * dC, reading A from its cache.
        cblas_dgemv(layout, (char)transpose(transA), M, N, alpha, A_cache, M, dC, incC, 1.0, dB, incB);

        // dC *= beta (gradient through the beta * C term).
        cblas_dscal(trans ? N : M, beta, dC, incC);

        checkTest(Test);

        // Check memory of primal of expected derivative
        checkMemoryTrace(inputs, "Expected " + Test, calls);

        // Check memory of primal of our derivative (if equal above, it
        // should be the same).
        checkMemoryTrace(inputs, "Found " + Test, foundCalls);

        // Drop this iteration's cache buffers from the tracked inputs.
        inputs[4] = BlasInfo();
        inputs[5] = BlasInfo();
      }


    }
  }
}
| |
| |
| static void symvTests() { |
| int N = 17; |
| // N means normal matrix, T means transposed |
| for (char layout : { CblasRowMajor, CblasColMajor }) { |
| for (auto uplo : {'U', 'u', 'L', 'l'}) |
| { |
| |
| std::string Test = "SYMV active A, C "; |
| BlasInfo inputs[6] = { |
| /*A*/ BlasInfo(A, layout, N, N, lda), |
| /*B*/ BlasInfo(B, N, incB), |
| /*C*/ BlasInfo(C, N, incC), |
| BlasInfo(), |
| BlasInfo(), |
| BlasInfo() |
| }; |
| |
| { |
| init(); |
| my_dsymv(layout, uplo, N, alpha, A, lda, B, incB, beta, C, incC); |
| |
| assert(calls.size() == 1); |
| assert(calls[0].inDerivative == false); |
| assert(calls[0].type == CallType::SYMV); |
| assert(calls[0].pout_arg1 == C); |
| assert(calls[0].pin_arg1 == A); |
| assert(calls[0].pin_arg2 == B); |
| assert(calls[0].farg1 == alpha); |
| assert(calls[0].farg2 == beta); |
| assert(calls[0].layout == layout); |
| assert(calls[0].targ1 == UNUSED_TRANS); |
| assert(calls[0].targ2 == UNUSED_TRANS); |
| assert(calls[0].iarg1 == N); |
| assert(calls[0].iarg3 == UNUSED_INT); |
| assert(calls[0].iarg4 == lda); |
| assert(calls[0].iarg5 == incB); |
| assert(calls[0].iarg6 == incC); |
| assert(calls[0].uplo == uplo); |
| |
| // Check memory of primal on own. |
| checkMemoryTrace(inputs, "Primal " + Test, calls); |
| |
| init(); |
| __enzyme_autodiff((void*) my_dsymv, |
| enzyme_const, layout, |
| enzyme_const, uplo, |
| enzyme_const, N, |
| enzyme_const, alpha, |
| enzyme_dup, A, dA, |
| enzyme_const, lda, |
| enzyme_const, B, |
| enzyme_const, incB, |
| enzyme_const, beta, |
| enzyme_dup, C, dC, |
| enzyme_const, incC); |
| foundCalls = calls; |
| init(); |
| |
| my_dsymv(layout, uplo, N, alpha, A, lda, B, incB, beta, C, incC); |
| |
| inDerivative = true; |
| |
| double *tmp = (double *)foundCalls[1].pout_arg1; |
| inputs[4] = BlasInfo(tmp, N, 1); |
| cblas_dcopy(N, dA, lda + 1, tmp, 1); |
| cblas_dsyr2(layout, uplo, N, alpha, B, incB, dC, incC, dA, lda); |
| cblas_dcopy(N, tmp, 1, dA, lda + 1); |
| |
| // dY = beta * dY |
| cblas_dscal(N, beta, dC, incC); |
| |
| checkTest(Test); |
| |
| SkipVecIncCheck = true; |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| SkipVecIncCheck = false; |
| } |
| |
| { |
| Test = "SYMV active A, B, C "; |
| |
| init(); |
| __enzyme_autodiff((void*) my_dsymv, |
| enzyme_const, layout, |
| enzyme_const, uplo, |
| enzyme_const, N, |
| enzyme_const, alpha, |
| enzyme_dup, A, dA, |
| enzyme_const, lda, |
| enzyme_dup, B, dB, |
| enzyme_const, incB, |
| enzyme_const, beta, |
| enzyme_dup, C, dC, |
| enzyme_const, incC); |
| foundCalls = calls; |
| init(); |
| |
| my_dsymv(layout, uplo, N, alpha, A, lda, B, incB, beta, C, incC); |
| |
| inDerivative = true; |
| |
| double *tmp = (double *)foundCalls[1].pout_arg1; |
| inputs[4] = BlasInfo(tmp, N, 1); |
| cblas_dcopy(N, dA, lda + 1, tmp, 1); |
| cblas_dsyr2(layout, uplo, N, alpha, B, incB, dC, incC, dA, lda); |
| cblas_dcopy(N, tmp, 1, dA, lda + 1); |
| |
| cblas_dsymv(layout, uplo, N, alpha, A, lda, dC, incC, 1.0, dB, incB); |
| |
| // dY = beta * dY |
| cblas_dscal(N, beta, dC, incC); |
| |
| |
| checkTest(Test); |
| |
| SkipVecIncCheck = true; |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| SkipVecIncCheck = false; |
| } |
| |
| { |
| |
| Test = "SYMV active/overwrite"; |
| |
| init(); |
| __enzyme_autodiff((void*) ow_dsymv, |
| enzyme_const, layout, |
| enzyme_const, uplo, |
| enzyme_const, N, |
| enzyme_const, alpha, |
| enzyme_dup, A, dA, |
| enzyme_const, lda, |
| enzyme_dup, B, dB, |
| enzyme_const, incB, |
| enzyme_const, beta, |
| enzyme_dup, C, dC, |
| enzyme_const, incC); |
| foundCalls = calls; |
| init(); |
| |
| assert(foundCalls.size() > 2); |
| auto A_cache = (double*)foundCalls[0].pout_arg1; |
| cblas_dlacpy(layout, uplo, N, N, A, lda, A_cache, N); |
| inputs[4] = BlasInfo(A_cache, layout, N, N, N); |
| auto B_cache = (double*)foundCalls[1].pout_arg1; |
| cblas_dcopy(N, B, incB, B_cache, 1); |
| inputs[5] = BlasInfo(B_cache, N, 1); |
| |
| ow_dsymv(layout, uplo, N, alpha, A, lda, B, incB, beta, C, incC); |
| |
| inDerivative = true; |
| |
| double *tmp = (double *)foundCalls[3].pout_arg1; |
| inputs[3] = BlasInfo(tmp, N, 1); |
| cblas_dcopy(N, dA, lda + 1, tmp, 1); |
| cblas_dsyr2(layout, uplo, N, alpha, B_cache, 1, dC, incC, dA, lda); |
| cblas_dcopy(N, tmp, 1, dA, lda + 1); |
| |
| cblas_dsymv(layout, uplo, N, alpha, A_cache, N, dC, incC, 1.0, dB, incB); |
| |
| // dY = beta * dY |
| cblas_dscal(N, beta, dC, incC); |
| |
| |
| checkTest(Test); |
| |
| SkipVecIncCheck = true; |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| SkipVecIncCheck = false; |
| |
| inputs[4] = BlasInfo(); |
| inputs[5] = BlasInfo(); |
| } |
| |
| |
| } |
| } |
| } |
| |
// GEMM tests. For each layout and each (transA, transB) pair:
//   1. run my_dgemm once and check the single recorded call's arguments;
//   2. "GEMM": gradient w.r.t. A, B and C;
//   3. "GEMM overwrite": same through ow_dgemm (no __restrict__), for which
//      Enzyme is expected to cache A and B.
// Each derivative case compares the calls Enzyme recorded (foundCalls) with a
// hand-written expected reverse pass (calls) via checkTest, then checks the
// memory accesses of both traces with checkMemoryTrace.
// In this test incB and incC double as the leading dimensions of B and C.
static void gemmTests() {
  // N means normal matrix, T means transposed
  for (char layout : { CblasRowMajor, CblasColMajor }) {
    for (auto transA : {CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans}) {
      for (auto transB : {CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans}) {
        // todo fortran blas {'N', 'n', 'T', 't'}

        {

          bool transA_bool = !is_normal(transA);
          bool transB_bool = !is_normal(transB);
          std::string Test = "GEMM";
          // Stored shapes: A is M x K (K x M if transposed), B is K x N
          // (N x K if transposed), C is M x N. Slots 4 and 5 receive the A/B
          // caches in the overwrite case.
          BlasInfo inputs[6] = {
            /*A*/ BlasInfo(A, layout, transA_bool ? K : M, transA_bool ? M : K, lda),
            /*B*/ BlasInfo(B, layout, transB_bool ? N : K , transB_bool ? K : N, incB),
            /*C*/ BlasInfo(C, layout, M, N, incC),
            BlasInfo(),
            BlasInfo(),
            BlasInfo()
          };
          init();
          my_dgemm(layout, (char)transA, (char)transB, M, N, K, alpha, A, lda, B, incB, beta, C, incC);

          // The primal must be exactly one GEMM with unchanged arguments.
          assert(calls.size() == 1);
          assert(calls[0].inDerivative == false);
          assert(calls[0].type == CallType::GEMM);
          assert(calls[0].pout_arg1 == C);
          assert(calls[0].pin_arg1 == A);
          assert(calls[0].pin_arg2 == B);
          assert(calls[0].farg1 == alpha);
          assert(calls[0].farg2 == beta);
          assert(calls[0].layout == layout);
          assert(calls[0].targ1 == (char)transA);
          assert(calls[0].targ2 == (char)transB);
          assert(calls[0].iarg1 == M);
          assert(calls[0].iarg2 == N);
          assert(calls[0].iarg3 == K);
          assert(calls[0].iarg4 == lda);
          assert(calls[0].iarg5 == incB);
          assert(calls[0].iarg6 == incC);

          // Check memory of primal on own.
          checkMemoryTrace(inputs, "Primal " + Test, calls);

          init();
          __enzyme_autodiff((void*) my_dgemm,
                            enzyme_const, layout,
                            enzyme_const, transA,
                            enzyme_const, transB,
                            enzyme_const, M,
                            enzyme_const, N,
                            enzyme_const, K,
                            enzyme_const, alpha,
                            enzyme_dup, A, dA,
                            enzyme_const, lda,
                            enzyme_dup, B, dB,
                            enzyme_const, incB,
                            enzyme_const, beta,
                            enzyme_dup, C, dC,
                            enzyme_const, incC);
          foundCalls = calls;
          init();


          my_dgemm(layout, (char)transA, (char)transB, M, N, K, alpha, A, lda, B, incB, beta, C, incC);

          inDerivative = true;

          // dA += alpha * dC * op(B)^T   (A not transposed; M x K)
          //  or   alpha * op(B) * dC^T   (A transposed;     K x M)
          my_dgemm(layout,
                   transA_bool ? (char)transB : (char)CBLAS_TRANSPOSE::CblasNoTrans,
                   transA_bool ? (char)CBLAS_TRANSPOSE::CblasTrans : (char)transpose(transB),
                   transA_bool ? K : M,
                   transA_bool ? M : K,
                   N,
                   alpha,
                   transA_bool ? B : dC,
                   transA_bool ? incB : incC,
                   transA_bool ? dC : B,
                   transA_bool ? incC : incB,
                   1.0, dA, lda);

          // dB += alpha * op(A)^T * dC   (B not transposed; K x N)
          //  or   alpha * dC^T * op(A)   (B transposed;     N x K)
          my_dgemm(layout,
                   transB_bool ? (char)CBLAS_TRANSPOSE::CblasTrans : (char)transpose(transA),
                   transB_bool ? (char)transA : (char)CBLAS_TRANSPOSE::CblasNoTrans, //transB,
                   transB_bool ? N : K,
                   transB_bool ? K : N,
                   M,
                   alpha,
                   transB_bool ? dC : A,
                   transB_bool ? incC : lda,
                   transB_bool ? A : dC,
                   transB_bool ? lda : incC,
                   1.0, dB, incB);

          // dC *= beta: dlascl scales the general ('G') M x N matrix by
          // cto / cfrom = beta / 1.0.
          cblas_dlascl(layout, 'G', 0, 0, 1.0, beta, M, N, dC, incC, 0 );

          checkTest(Test);

          // Check memory of primal of expected derivative
          checkMemoryTrace(inputs, "Expected " + Test, calls);

          // Check memory of primal of our derivative (if equal above, it
          // should be the same).
          checkMemoryTrace(inputs, "Found " + Test, foundCalls);


          Test = "GEMM overwrite";

          init();
          __enzyme_autodiff((void*) ow_dgemm,
                            enzyme_const, layout,
                            enzyme_const, transA,
                            enzyme_const, transB,
                            enzyme_const, M,
                            enzyme_const, N,
                            enzyme_const, K,
                            enzyme_const, alpha,
                            enzyme_dup, A, dA,
                            enzyme_const, lda,
                            enzyme_dup, B, dB,
                            enzyme_const, incB,
                            enzyme_const, beta,
                            enzyme_dup, C, dC,
                            enzyme_const, incC);
          foundCalls = calls;
          init();

          // The first two calls Enzyme recorded are expected to write its
          // caches of A and B, each stored with leading dimension equal to
          // its stored row count. Replay those copies into the same buffers.
          assert(foundCalls.size() > 2);
          auto A_cache = (double*)foundCalls[0].pout_arg1;
          cblas_dlacpy(layout, '\0', (!transA_bool) ? M : K, (!transA_bool) ? K : M, A, lda, A_cache, (!transA_bool) ? M : K);
          inputs[4] = BlasInfo(A_cache, layout, (!transA_bool) ? M : K, (!transA_bool) ? K : M, (!transA_bool) ? M : K);
          auto B_cache = (double*)foundCalls[1].pout_arg1;
          cblas_dlacpy(layout, '\0', (!transB_bool) ? K : N, (!transB_bool) ? N : K, B, incB, B_cache, (!transB_bool) ? K : N);
          inputs[5] = BlasInfo(B_cache, layout, (!transB_bool) ? K : N, (!transB_bool) ? N : K, (!transB_bool) ? K : N);

          ow_dgemm(layout, (char)transA, (char)transB, M, N, K, alpha, A, lda, B, incB, beta, C, incC);

          inDerivative = true;

          // dA update as above, reading B from its cache.
          my_dgemm(layout,
                   transA_bool ? (char)transB : (char)CBLAS_TRANSPOSE::CblasNoTrans,
                   transA_bool ? (char)CBLAS_TRANSPOSE::CblasTrans : (char)transpose(transB),
                   transA_bool ? K : M,
                   transA_bool ? M : K,
                   N,
                   alpha,
                   transA_bool ? B_cache : dC,
                   transA_bool ? ( (!transB_bool) ? K : N ) : incC,
                   transA_bool ? dC : B_cache,
                   transA_bool ? incC : ( (!transB_bool) ? K : N),
                   1.0, dA, lda);

          // dB update as above, reading A from its cache.
          my_dgemm(layout,
                   transB_bool ? (char)CBLAS_TRANSPOSE::CblasTrans : (char)transpose(transA),
                   transB_bool ? (char)transA : (char)CBLAS_TRANSPOSE::CblasNoTrans, //transB,
                   transB_bool ? N : K,
                   transB_bool ? K : N,
                   M,
                   alpha,
                   transB_bool ? dC : A_cache,
                   transB_bool ? incC : ( (!transA_bool) ? M : K),
                   transB_bool ? A_cache : dC,
                   transB_bool ? ( (!transA_bool) ? M : K) : incC,
                   1.0, dB, incB);

          // dC *= beta.
          cblas_dlascl(layout, 'G', 0, 0, 1.0, beta, M, N, dC, incC, 0 );

          checkTest(Test);

          // Check memory of primal of expected derivative
          checkMemoryTrace(inputs, "Expected " + Test, calls);

          // Check memory of primal of our derivative (if equal above, it
          // should be the same).
          checkMemoryTrace(inputs, "Found " + Test, foundCalls);
        }


      }
    }
  }
}
| |
| static void trmvTests() { |
| REALCOPY = true; |
| // N means normal matrix, T means transposed |
| int N = 7; |
| double* B = (double*)malloc(sizeof(double*)*incB*N); |
| double* dB = (double*)malloc(sizeof(double*)*incB*N); |
| // TODO row major |
| for (char layout : { CblasColMajor, /*CblasRowMajor */}) { |
| |
| for (auto uplo : {'U', 'u', 'L', 'l'}) |
| |
| for (auto diag : {'U', 'u', 'N', 'n'}) |
| |
| for (auto transA : {CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans}) |
| |
| { |
| // todo in fortran blas consider 'N', 'n', 'T', 't'} |
| |
| { |
| |
| bool trans = !is_normal(transA); |
| std::string Test = "TRMV active A, C "; |
| BlasInfo inputs[6] = { |
| /*A*/ BlasInfo(A, layout, N, N, lda), |
| /*B*/ /*BlasInfo(B, N, incB),*/BlasInfo(), |
| BlasInfo(), |
| BlasInfo(), |
| BlasInfo(B, N, incB), |
| BlasInfo(dB, N, incB) |
| }; |
| init(); |
| |
| for (int i=0; i<N*incB; i++) { |
| B[i] = i *1e-4; |
| dB[i] = -i *1e-4; |
| } |
| for (size_t i=0; i<N; i++) { |
| B[incB*i] = 7 + i; |
| dB[incB*i] = 300 + i; |
| } |
| my_dtrmv(layout, uplo, (char)transA, diag, N, A, lda, B, incB); |
| |
| assert(calls.size() == 1); |
| assert(calls[0].inDerivative == false); |
| assert(calls[0].type == CallType::TRMV); |
| assert(calls[0].pout_arg1 == B); |
| assert(calls[0].pin_arg1 == A); |
| assert(calls[0].pin_arg2 == UNUSED_POINTER); |
| assert(calls[0].farg1 == UNUSED_DOUBLE); |
| assert(calls[0].farg2 == UNUSED_DOUBLE); |
| assert(calls[0].layout == layout); |
| assert(calls[0].targ1 == (char)transA); |
| assert(calls[0].targ2 == UNUSED_TRANS); |
| assert(calls[0].iarg1 == N); |
| assert(calls[0].iarg2 == UNUSED_INT); |
| assert(calls[0].iarg3 == UNUSED_INT); |
| assert(calls[0].iarg4 == lda); |
| assert(calls[0].iarg5 == incB); |
| assert(calls[0].iarg6 == UNUSED_INT); |
| assert(calls[0].uplo == uplo); |
| assert(calls[0].diag == diag); |
| |
| // Check memory of primal on own. |
| checkMemoryTrace(inputs, "Primal " + Test, calls); |
| |
| init(); |
| for (int i=0; i<N*incB; i++) { |
| B[i] = i *1e-4; |
| dB[i] = -i *1e-4; |
| } |
| for (size_t i=0; i<N; i++) { |
| B[incB*i] = 7 + i; |
| dB[incB*i] = 300 + i; |
| } |
| __enzyme_autodiff((void*) my_dtrmv, |
| enzyme_const, layout, |
| enzyme_const, uplo, |
| enzyme_const, transA, |
| enzyme_const, diag, |
| enzyme_const, N, |
| enzyme_dup, A, dA, |
| enzyme_const, lda, |
| enzyme_dup, B, dB, |
| enzyme_const, incB); |
| foundCalls = calls; |
| init(); |
| |
| for (int i=0; i<N*incB; i++) { |
| B[i] = i *1e-4; |
| dB[i] = -i *1e-4; |
| } |
| for (size_t i=0; i<N; i++) { |
| B[incB*i] = 7 + i; |
| dB[incB*i] = 300 + i; |
| } |
| assert(foundCalls.size() >= 2); |
| assert(foundCalls[0].type == CallType::COPY); |
| double* cacheB = (double*)foundCalls[0].pout_arg1; |
| |
| cblas_dcopy(N, B, incB, cacheB, 1); |
| inputs[3] = BlasInfo(cacheB, N, 1); |
| auto B0 = cacheB; |
| |
| my_dtrmv(layout, uplo, (char)transA, diag, N, A, lda, B, incB); |
| |
| inDerivative = true; |
| |
| auto d = (diag == 'n' || diag == 'N') ? 0 : 1; |
| |
| #define Aa(r,c) dA[(r-1)*(layout == CblasRowMajor ? lda : 1) + (c-1)*(layout == CblasRowMajor ? 1 : lda) ] |
| |
| if (is_normal(transA)) { |
| if (uplo == 'u' || uplo == 'U') { |
| for (int i=1; i<=N; i++) { |
| cblas_daxpy(i-d, B0[i-1], dB, incB, &Aa(1, i), 1); |
| } |
| } else { |
| // A is lower triangular |
| for (int i=1; i<=N-d; i++) |
| cblas_daxpy(N-i+1-d, B0[i-1], &dB[(i+d-1)*incB], incB, &Aa(i+d,i), 1); |
| } |
| } else { |
| // BLAS operation |
| // x := A'*x where A is triangular |
| // RMD operation |
| // Aa += x*xa' |
| if( uplo == 'u' || uplo == 'U') { |
| // A is upper triangular |
| for (int i=1; i<=N; i++) |
| cblas_daxpy(i-d, dB[(i-1)*incB], B0, 1, &Aa(1, i), 1); |
| } else { |
| // A is lower triangular |
| for (int i=1; i<=N-d; i++) |
| cblas_daxpy(N-i+1-d, dB[(i-1)*incB], &B0[i+d-1], 1, &Aa(i+d,i), 1); |
| } |
| } |
| |
| cblas_dtrmv(layout, uplo, (char)transpose(transA), diag, N, A, lda, dB, incB); |
| |
| checkTest(Test); |
| |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| |
| } |
| |
| |
| } |
| } |
| REALCOPY = false; |
| } |
| |
| static void trmmTests() { |
| // N means normal matrix, T means transposed |
| // TODO: row major is presently an exepcted failure. We should re-enable. |
| for (char layout : { CblasColMajor, /*CblasRowMajor*/ }) { |
| |
| for (auto side : {'L', 'l', 'R', 'r'}) |
| |
| for (auto transA : {CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans}) |
| |
| for (auto uplo : {'U', 'u', 'L', 'l'}) |
| |
| for (auto diag : {'U', 'u', 'N', 'n'}) |
| |
| { |
| // todo in fortran blas consider 'N', 'n', 'T', 't'} |
| |
| int N = 7; |
| int M = 13; |
| { |
| |
| bool trans = !is_normal(transA); |
| std::string Test = "TRMM active A, B "; |
| BlasInfo inputs[6] = { |
| /*A*/ BlasInfo(A, layout, (side == 'L' || side == 'l') ? M : N, (side == 'L' || side == 'l') ? M : N, lda), |
| /*B*/ BlasInfo(B, layout, M, N, incB), |
| BlasInfo(), |
| BlasInfo(), |
| BlasInfo(), |
| BlasInfo() |
| }; |
| init(); |
| |
| my_dtrmm(layout, side, uplo, (char)transA, diag, M, N, alpha, A, lda, B, incB); |
| |
| assert(calls.size() == 1); |
| assert(calls[0].inDerivative == false); |
| assert(calls[0].type == CallType::TRMM); |
| assert(calls[0].pout_arg1 == B); |
| assert(calls[0].pin_arg1 == A); |
| assert(calls[0].pin_arg2 == UNUSED_POINTER); |
| assert(calls[0].farg1 == alpha); |
| assert(calls[0].farg2 == UNUSED_DOUBLE); |
| assert(calls[0].layout == layout); |
| assert(calls[0].targ1 == (char)transA); |
| assert(calls[0].targ2 == UNUSED_TRANS); |
| assert(calls[0].iarg1 == M); |
| assert(calls[0].iarg2 == N); |
| assert(calls[0].iarg3 == UNUSED_INT); |
| assert(calls[0].iarg4 == lda); |
| assert(calls[0].iarg5 == incB); |
| assert(calls[0].iarg6 == UNUSED_INT); |
| assert(calls[0].side == side); |
| assert(calls[0].uplo == uplo); |
| assert(calls[0].diag == diag); |
| |
| // Check memory of primal on own. |
| checkMemoryTrace(inputs, "Primal " + Test, calls); |
| |
| init(); |
| __enzyme_autodiff((void*) my_dtrmm, |
| enzyme_const, layout, |
| enzyme_const, side, |
| enzyme_const, uplo, |
| enzyme_const, transA, |
| enzyme_const, diag, |
| enzyme_const, M, |
| enzyme_const, N, |
| enzyme_const, alpha, |
| enzyme_dup, A, dA, |
| enzyme_const, lda, |
| enzyme_dup, B, dB, |
| enzyme_const, incB); |
| foundCalls = calls; |
| init(); |
| |
| |
| double* cacheB = (double*)foundCalls[0].pout_arg1; |
| |
| cblas_dlacpy(layout, '\0', M, N, |
| B, |
| incB, cacheB, M); |
| inputs[4] = BlasInfo(cacheB, layout, M, N, M); |
| my_dtrmm(layout, side, uplo, (char)transA, diag, M, N, alpha, A, lda, B, incB); |
| |
| assert(foundCalls.size() >= 2); |
| assert(foundCalls[0].type == CallType::LACPY); |
| inDerivative = true; |
| |
| auto d = (diag == 'n' || diag == 'N') ? 0 : 1; |
| |
| #define B0(r,c) cacheB[(r-1)*(layout == CblasRowMajor ? M : 1) + (c-1)*(layout == CblasRowMajor ? 1 : M) ] |
| #define Ba(r,c) dB[(r-1)*(layout == CblasRowMajor ? incB : 1) + (c-1)*(layout == CblasRowMajor ? 1 : incB) ] |
| #define Aa(r,c) dA[(r-1)*(layout == CblasRowMajor ? lda : 1) + (c-1)*(layout == CblasRowMajor ? 1 : lda) ] |
| |
| auto ldb = incB; |
| |
| char toTrans; |
| if (side == 'l') |
| toTrans = 'n'; |
| else if (side == 'L') |
| toTrans = 'N'; |
| else if (side == 'r') |
| toTrans = 't'; |
| else if (side == 'R') |
| toTrans = 'T'; |
| |
| if (side == 'l' || side == 'L') { |
| if (is_normal(transA)) { |
| // BLAS operation |
| // B = alpha*A*B0 |
| // RMD operation |
| // Aa += alpha*Ba*B0' |
| if(uplo == 'u' || uplo == 'U') { |
| // A is upper triangular |
| for (int i=1; i<=M; i++) |
| cblas_dgemv(layout, toTrans,i-d,N, alpha,dB,incB,&B0(i, 1),M,1.0,&Aa(1, i),1); |
| } else { |
| // A is lower triangular |
| for (int i=1; i<=M-d; i++) |
| cblas_dgemv(layout, toTrans,M-i+1-d,N,alpha,&Ba(i+d,1),ldb,&B0(i,1),M,1.0, &Aa(i+d,i),1); |
| } |
| } else { |
| // BLAS operation |
| // B = alpha*A'*B0 |
| // RMD operation |
| // Aa += alpha*B*Ba' |
| if(uplo == 'u' || uplo == 'U') { |
| // A is upper triangular |
| for (int i=1; i<=M; i++) |
| cblas_dgemv(layout, toTrans,i-d,N, alpha,&B0(1,1),M,&Ba(i,1),ldb,1.0,&Aa(1,i),1); |
| } else { |
| // A is lower triangular |
| for (int i=1; i<=M-d; i++) |
| cblas_dgemv(layout, toTrans,M-i+1-d,N,alpha,&B0(i+d,1),M,&Ba(i,1),ldb,1.0, &Aa(i+d,i),1); |
| } |
| } |
| } else { |
| if (is_normal(transA)) { |
| // BLAS operation |
| // B = alpha*B0*A |
| // RMD operation |
| // Aa += alpha*B0'*Ba |
| if(uplo == 'u' || uplo == 'U') { |
| // A is upper triangular |
| for (int i=1; i<=N; i++) |
| cblas_dgemv(layout, toTrans,M,i-d,alpha,&B0(1,1),M,&Ba(1,i),1, 1.0,&Aa(1,i),1); |
| } else { |
| // A is lower triangular |
| for (int i=1; i<=N-d; i++) |
| cblas_dgemv(layout, toTrans,M,N-i+1-d,alpha,&B0(1,i+d),M,&Ba(1,i),1, 1.0, & |
| Aa(i+d,i),1); |
| } |
| } else { |
| // BLAS operation |
| // B = alpha*B0*A' |
| // RMD operation |
| // Aa += alpha*Ba'*B0 |
| if(uplo == 'u' || uplo == 'U') { |
| // A is upper triangular |
| for (int i=1; i<=N; i++) |
| cblas_dgemv(layout, toTrans,M,i-d,alpha,&Ba(1,1),ldb,&B0(1,i),1, 1.0,&Aa(1,i),1); |
| } else { |
| // A is lower triangular |
| for (int i=1; i<=N-d; i++) |
| cblas_dgemv(layout, toTrans,M,N-i+1-d,alpha,&Ba(1,i+d),ldb,&B0(1,i),1, 1.0, &Aa(i+d,i),1); |
| } |
| } |
| } |
| |
| cblas_dtrmm(layout, side, uplo, (char)transpose(transA), diag, M, N, alpha, A, lda, dB, incB); |
| |
| checkTest(Test); |
| |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| |
| } |
| |
| { |
| |
| bool trans = !is_normal(transA); |
| std::string Test = "TRMM overwrite active A, B "; |
| BlasInfo inputs[6] = { |
| /*A*/ BlasInfo(A, layout, (side == 'L' || side == 'l') ? M : N, (side == 'L' || side == 'l') ? M : N, lda), |
| /*B*/ BlasInfo(B, layout, M, N, incB), |
| BlasInfo(), |
| BlasInfo(), |
| BlasInfo(), |
| BlasInfo() |
| }; |
| init(); |
| |
| ow_dtrmm(layout, side, uplo, (char)transA, diag, M, N, alpha, A, lda, B, incB); |
| |
| // Check memory of primal on own. |
| checkMemoryTrace(inputs, "Primal " + Test, calls); |
| |
| init(); |
| __enzyme_autodiff((void*) ow_dtrmm, |
| enzyme_const, layout, |
| enzyme_const, side, |
| enzyme_const, uplo, |
| enzyme_const, transA, |
| enzyme_const, diag, |
| enzyme_const, M, |
| enzyme_const, N, |
| enzyme_const, alpha, |
| enzyme_dup, A, dA, |
| enzyme_const, lda, |
| enzyme_dup, B, dB, |
| enzyme_const, incB); |
| foundCalls = calls; |
| init(); |
| |
| |
| double* cacheA = (double*)foundCalls[0].pout_arg1; |
| |
| cblas_dlacpy(layout, '\0', is_left(side) ? M : N, is_left(side) ? M : N, |
| A, |
| lda, cacheA, is_left(side) ? M : N); |
| inputs[5] = BlasInfo(cacheA, layout, is_left(side) ? M : N, is_left(side) ? M : N, is_left(side) ? M : N); |
| |
| double* cacheB = (double*)foundCalls[1].pout_arg1; |
| |
| cblas_dlacpy(layout, '\0', M, N, |
| B, |
| incB, cacheB, M); |
| inputs[4] = BlasInfo(cacheB, layout, M, N, M); |
| |
| ow_dtrmm(layout, side, uplo, (char)transA, diag, M, N, alpha, A, lda, B, incB); |
| |
| assert(foundCalls.size() >= 2); |
| assert(foundCalls[0].type == CallType::LACPY); |
| inDerivative = true; |
| |
| auto d = (diag == 'n' || diag == 'N') ? 0 : 1; |
| |
| #define B0(r,c) cacheB[(r-1)*(layout == CblasRowMajor ? M : 1) + (c-1)*(layout == CblasRowMajor ? 1 : M) ] |
| #define Ba(r,c) dB[(r-1)*(layout == CblasRowMajor ? incB : 1) + (c-1)*(layout == CblasRowMajor ? 1 : incB) ] |
| #define Aa(r,c) dA[(r-1)*(layout == CblasRowMajor ? lda : 1) + (c-1)*(layout == CblasRowMajor ? 1 : lda) ] |
| |
| auto ldb = incB; |
| |
| char toTrans; |
| if (side == 'l') |
| toTrans = 'n'; |
| else if (side == 'L') |
| toTrans = 'N'; |
| else if (side == 'r') |
| toTrans = 't'; |
| else if (side == 'R') |
| toTrans = 'T'; |
| |
| if (side == 'l' || side == 'L') { |
| if (is_normal(transA)) { |
| // BLAS operation |
| // B = alpha*A*B0 |
| // RMD operation |
| // Aa += alpha*Ba*B0' |
| if(uplo == 'u' || uplo == 'U') { |
| // A is upper triangular |
| for (int i=1; i<=M; i++) |
| cblas_dgemv(layout, toTrans,i-d,N, alpha,dB,incB,&B0(i, 1),M,1.0,&Aa(1, i),1); |
| } else { |
| // A is lower triangular |
| for (int i=1; i<=M-d; i++) |
| cblas_dgemv(layout, toTrans,M-i+1-d,N,alpha,&Ba(i+d,1),ldb,&B0(i,1),M,1.0, &Aa(i+d,i),1); |
| } |
| } else { |
| // BLAS operation |
| // B = alpha*A'*B0 |
| // RMD operation |
| // Aa += alpha*B*Ba' |
| if(uplo == 'u' || uplo == 'U') { |
| // A is upper triangular |
| for (int i=1; i<=M; i++) |
| cblas_dgemv(layout, toTrans,i-d,N, alpha,&B0(1,1),M,&Ba(i,1),ldb,1.0,&Aa(1,i),1); |
| } else { |
| // A is lower triangular |
| for (int i=1; i<=M-d; i++) |
| cblas_dgemv(layout, toTrans,M-i+1-d,N,alpha,&B0(i+d,1),M,&Ba(i,1),ldb,1.0, &Aa(i+d,i),1); |
| } |
| } |
| } else { |
| if (is_normal(transA)) { |
| // BLAS operation |
| // B = alpha*B0*A |
| // RMD operation |
| // Aa += alpha*B0'*Ba |
| if(uplo == 'u' || uplo == 'U') { |
| // A is upper triangular |
| for (int i=1; i<=N; i++) |
| cblas_dgemv(layout, toTrans,M,i-d,alpha,&B0(1,1),M,&Ba(1,i),1, 1.0,&Aa(1,i),1); |
| } else { |
| // A is lower triangular |
| for (int i=1; i<=N-d; i++) |
| cblas_dgemv(layout, toTrans,M,N-i+1-d,alpha,&B0(1,i+d),M,&Ba(1,i),1, 1.0, & |
| Aa(i+d,i),1); |
| } |
| } else { |
| // BLAS operation |
| // B = alpha*B0*A' |
| // RMD operation |
| // Aa += alpha*Ba'*B0 |
| if(uplo == 'u' || uplo == 'U') { |
| // A is upper triangular |
| for (int i=1; i<=N; i++) |
| cblas_dgemv(layout, toTrans,M,i-d,alpha,&Ba(1,1),ldb,&B0(1,i),1, 1.0,&Aa(1,i),1); |
| } else { |
| // A is lower triangular |
| for (int i=1; i<=N-d; i++) |
| cblas_dgemv(layout, toTrans,M,N-i+1-d,alpha,&Ba(1,i+d),ldb,&B0(1,i),1, 1.0, &Aa(i+d,i),1); |
| } |
| } |
| } |
| |
| cblas_dtrmm(layout, side, uplo, (char)transpose(transA), diag, M, N, alpha, cacheA, is_left(side) ? M : N, dB, incB); |
| |
| checkTest(Test); |
| |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| |
| } |
| |
| |
| } |
| } |
| } |
| |
| static void syrkTests() { |
| int N = 13; |
| int K = 7; |
| double *C = (double *)malloc(sizeof(double *) * incC * N * N); |
| double *dC = (double *)malloc(sizeof(double *) * incC * N * N); |
| // N means normal matrix, T means transposed |
| // TODO: row major is presently an exepcted failure. We should re-enable. |
| for (char layout : {CblasColMajor, /*CblasRowMajor*/}) { |
| |
| for (auto transA : |
| {CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans}) |
| |
| for (auto uplo : {'U', 'u', 'L', 'l'}) |
| |
| { |
| |
| { |
| |
| bool trans = !is_normal(transA); |
| std::string Test = "SYRK active C, B "; |
| BlasInfo inputs[6] = { |
| /*A*/ BlasInfo(A, layout, trans ? K : N, trans ? N : K, lda), |
| /*B*/ BlasInfo(), |
| /*C*/ BlasInfo(), |
| BlasInfo(), |
| /*C*/ BlasInfo(C, layout, N, N, incC), |
| /*C*/ BlasInfo(dC, layout, N, N, incC), |
| }; |
| init(); |
| |
| for (int i = 0; i < N * N * incC; i++) { |
| C[i] = i * 1e-4; |
| dC[i] = -i * 1e-4; |
| } |
| for (size_t i = 0; i < N * N; i++) { |
| C[incC * i] = 7 + i; |
| dC[incC * i] = 300 + i; |
| } |
| my_dsyrk(layout, uplo, (char)transA, N, K, alpha, A, lda, beta, C, |
| incC); |
| |
| assert(calls.size() == 1); |
| assert(calls[0].inDerivative == false); |
| assert(calls[0].type == CallType::SYRK); |
| assert(calls[0].pout_arg1 == C); |
| assert(calls[0].pin_arg1 == A); |
| assert(calls[0].pin_arg2 == UNUSED_POINTER); |
| assert(calls[0].farg1 == alpha); |
| assert(calls[0].farg2 == beta); |
| assert(calls[0].layout == layout); |
| assert(calls[0].targ1 == (char)transA); |
| assert(calls[0].targ2 == UNUSED_TRANS); |
| assert(calls[0].iarg1 == N); |
| assert(calls[0].iarg2 == K); |
| assert(calls[0].iarg3 == UNUSED_INT); |
| assert(calls[0].iarg4 == lda); |
| assert(calls[0].iarg5 == incC); |
| assert(calls[0].iarg6 == UNUSED_INT); |
| assert(calls[0].side == UNUSED_TRANS); |
| assert(calls[0].uplo == uplo); |
| assert(calls[0].diag == UNUSED_TRANS); |
| |
| // Check memory of primal on own. |
| checkMemoryTrace(inputs, "Primal " + Test, calls); |
| |
| init(); |
| for (int i = 0; i < N * N * incC; i++) { |
| C[i] = i * 1e-4; |
| dC[i] = -i * 1e-4; |
| } |
| for (size_t i = 0; i < N * N; i++) { |
| C[incC * i] = 7 + i; |
| dC[incC * i] = 300 + i; |
| } |
| __enzyme_autodiff( |
| (void *)my_dsyrk, enzyme_const, layout, enzyme_const, uplo, |
| enzyme_const, transA, enzyme_const, N, enzyme_const, K, |
| enzyme_const, alpha, enzyme_dup, A, dA, enzyme_const, lda, |
| enzyme_const, beta, enzyme_dup, C, dC, enzyme_const, incC); |
| foundCalls = calls; |
| init(); |
| |
| for (int i = 0; i < N * N * incC; i++) { |
| C[i] = i * 1e-4; |
| dC[i] = -i * 1e-4; |
| } |
| for (size_t i = 0; i < N * N; i++) { |
| C[incC * i] = 7 + i; |
| dC[incC * i] = 300 + i; |
| } |
| |
| my_dsyrk(layout, uplo, (char)transA, N, K, alpha, A, lda, beta, C, |
| incC); |
| |
| inDerivative = true; |
| |
| #define Av(r, c) \ |
| A[(r - 1) * (layout == CblasRowMajor ? lda : 1) + \ |
| (c - 1) * (layout == CblasRowMajor ? 1 : lda)] |
| #define Aa(r, c) \ |
| dA[(r - 1) * (layout == CblasRowMajor ? lda : 1) + \ |
| (c - 1) * (layout == CblasRowMajor ? 1 : lda)] |
| |
| #define Ca(r, c) \ |
| dC[(r - 1) * (layout == CblasRowMajor ? incC : 1) + \ |
| (c - 1) * (layout == CblasRowMajor ? 1 : incC)] |
| |
| if (is_normal(transA)) { |
| // BLAS operation |
| // C = alpha*A*A' + beta*C |
| // RMD op |
| // Aa += alpha*(Ca+diag(Ca))*A |
| cblas_dsymm(layout, 'l', uplo, N, K, alpha, dC, incC, A, lda, 1.0, |
| dA, lda); |
| for (int i = 1; i <= N; i++) |
| cblas_daxpy(K, alpha * Ca(i, i), &Av(i, 1), lda, &Aa(i, 1), lda); |
| } else { |
| // BLAS operation |
| // C = alpha*A'*A + beta*C |
| // RMD operation |
| // Aa += alpha*A*(Ca+diag(Ca)) |
| cblas_dsymm(layout, 'r', uplo, K, N, alpha, dC, incC, A, lda, 1.0, |
| dA, lda); |
| for (int i = 1; i <= N; i++) |
| cblas_daxpy(K, alpha * Ca(i, i), &Av(1, i), 1, &Aa(1, i), 1); |
| } |
| cblas_dlascl(layout, uplo, 0, 0, 1.0, beta, N, N, dC, incC, 0); |
| |
| checkTest(Test); |
| |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| } |
| } |
| } |
| free(C); |
| free(dC); |
| } |
| |
| static void potrfTests() { |
| int N = 17; |
| // N means normal matrix, T means transposed |
| for (char layout : {CblasColMajor, CblasRowMajor}) { |
| for (auto uplo : {'U', 'u', 'L', 'l'}) |
| |
| { |
| BlasInfo inputs[6] = { |
| /*A*/ BlasInfo(A, layout, N, N, lda), |
| /*B*/ BlasInfo(), |
| /*C*/ BlasInfo(), |
| BlasInfo(), |
| BlasInfo(), |
| BlasInfo(), |
| }; |
| { |
| |
| std::string Test = "POTRF active A "; |
| init(); |
| |
| my_potrf(layout, uplo, N, A, lda); |
| |
| assert(calls.size() == 1); |
| assert(calls[0].inDerivative == false); |
| assert(calls[0].type == CallType::POTRF); |
| assert(calls[0].pout_arg1 == A); |
| assert(calls[0].pin_arg1 == UNUSED_POINTER); |
| assert(calls[0].pin_arg2 == UNUSED_POINTER); |
| assert(calls[0].farg1 == UNUSED_DOUBLE); |
| assert(calls[0].farg2 == UNUSED_DOUBLE); |
| assert(calls[0].layout == layout); |
| assert(calls[0].targ1 == UNUSED_TRANS); |
| assert(calls[0].targ2 == UNUSED_TRANS); |
| assert(calls[0].iarg1 == N); |
| assert(calls[0].iarg2 == UNUSED_INT); |
| assert(calls[0].iarg3 == UNUSED_INT); |
| assert(calls[0].iarg4 == lda); |
| assert(calls[0].iarg5 == UNUSED_INT); |
| assert(calls[0].iarg6 == UNUSED_INT); |
| assert(calls[0].side == UNUSED_TRANS); |
| assert(calls[0].uplo == uplo); |
| assert(calls[0].diag == UNUSED_TRANS); |
| |
| // Check memory of primal on own. |
| checkMemoryTrace(inputs, "Primal " + Test, calls); |
| |
| init(); |
| __enzyme_autodiff((void *)my_potrf, enzyme_const, layout, enzyme_const, |
| uplo, enzyme_const, N, enzyme_dup, A, dA, |
| enzyme_const, lda); |
| foundCalls = calls; |
| init(); |
| |
| my_potrf(layout, uplo, N, A, lda); |
| |
| inDerivative = true; |
| |
| assert(foundCalls.size() >= 2); |
| assert(foundCalls[1].type == CallType::LACPY); |
| double *tri = (double *)foundCalls[1].pout_arg1; |
| inputs[3] = BlasInfo(tri, layout, N, N, N); |
| |
| cblas_dlacpy(layout, uplo, N, N, dA, lda, tri, N); |
| |
| cblas_dtrmm(layout, uplo_to_side(uplo), uplo, 'T', 'N', N, N, 1.0, |
| A, lda, tri, N); |
| |
| assert(foundCalls.size() >= 5); |
| assert(foundCalls[3].type == CallType::COPY); |
| double *tmp = (double *)foundCalls[3].pout_arg1; |
| inputs[4] = BlasInfo(tmp, N, 1); |
| |
| cblas_dcopy(N, tri, N + 1, tmp, 1); |
| cblas_dscal(N, 0.5, tmp, 1); |
| cblas_dlascl(layout, flip_uplo(uplo), 0, 0, 1.0, 0.0, N, N, tri, N, 0); |
| cblas_dcopy(N, tmp, 1, tri, N + 1); |
| |
| cblas_dtrsm(layout, uplo_to_rside(uplo), uplo, 'N', 'N', N, N, 1.0, |
| A, lda, tri, N); |
| cblas_dtrsm(layout, uplo_to_side(uplo), uplo, 'T', 'N', N, N, 1.0, |
| A, lda, tri, N); |
| #define triv(r, c) \ |
| tri[(r) * (layout == CblasRowMajor ? N : 1) + \ |
| (c) * (layout == CblasRowMajor ? 1 : N)] |
| |
| int upperinc = (&triv(0, 1) - &triv(0,0)); |
| int lowerinc = (&triv(1, 0) - &triv(0,0)); |
| if (layout == CblasColMajor) { |
| assert(upperinc == N); |
| assert(lowerinc == 1); |
| } else { |
| assert(upperinc == 1); |
| assert(lowerinc == N); |
| } |
| bool is_lower = uplo == 'L' || uplo == 'l'; |
| for (int i = 0; i < N - 1; i++) { |
| cblas_daxpy(N - i - 1, 1.0, |
| is_lower ? &triv(i, i + 1) : &triv(i + 1, i), |
| is_lower ? upperinc : lowerinc, |
| is_lower ? &triv(i + 1, i) : &triv(i, i + 1), |
| is_lower ? lowerinc : upperinc); |
| } |
| |
| cblas_dlacpy(layout, uplo, N, N, tri, N, dA, lda); |
| |
| checkTest(Test); |
| |
| SkipVecIncCheck = true; |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| SkipVecIncCheck = false; |
| } |
| { |
| |
| std::string Test = "POTRF overwrite A "; |
| init(); |
| |
| ow_potrf(layout, uplo, N, A, lda); |
| |
| // Check memory of primal on own. |
| checkMemoryTrace(inputs, "Primal " + Test, calls); |
| |
| init(); |
| __enzyme_autodiff((void *)ow_potrf, enzyme_const, layout, enzyme_const, |
| uplo, enzyme_const, N, enzyme_dup, A, dA, |
| enzyme_const, lda); |
| foundCalls = calls; |
| init(); |
| |
| cblas_dpotrf(layout, uplo, N, A, lda, nullptr); |
| double *cacheA = (double *)foundCalls[1].pout_arg1; |
| inputs[5] = BlasInfo(cacheA, (char)layout, N, N, N); |
| assert(inputs[5].ty == ValueType::Matrix); |
| cblas_dlacpy(layout, uplo, N, N, A, lda, cacheA, N); |
| cblas_dscal(1, 0.0, A, lda); |
| |
| inDerivative = true; |
| cblas_dscal(1, 0.0, dA, lda); |
| |
| assert(foundCalls.size() >= 2); |
| assert(foundCalls[4].type == CallType::LACPY); |
| double *tri = (double *)foundCalls[4].pout_arg1; |
| inputs[3] = BlasInfo(tri, (char)layout, N, N, N); |
| |
| cblas_dlacpy(layout, uplo, N, N, dA, lda, tri, N); |
| |
| cblas_dtrmm(layout, uplo_to_side(uplo), uplo, 'T', 'N', N, N, 1.0, |
| cacheA, N, tri, N); |
| |
| assert(foundCalls.size() >= 5); |
| assert(foundCalls[6].type == CallType::COPY); |
| double *tmp = (double *)foundCalls[6].pout_arg1; |
| inputs[4] = BlasInfo(tmp, N, 1); |
| |
| cblas_dcopy(N, tri, N + 1, tmp, 1); |
| cblas_dscal(N, 0.5, tmp, 1); |
| cblas_dlascl(layout, flip_uplo(uplo), 0, 0, 1.0, 0.0, N, N, tri, N, 0); |
| cblas_dcopy(N, tmp, 1, tri, N + 1); |
| |
| cblas_dtrsm(layout, uplo_to_rside(uplo), uplo, 'N', 'N', N, N, 1.0, |
| cacheA, N, tri, N); |
| cblas_dtrsm(layout, uplo_to_side(uplo), uplo, 'T', 'N', N, N, 1.0, |
| cacheA, N, tri, N); |
| #define triv(r, c) \ |
| tri[(r) * (layout == CblasRowMajor ? N : 1) + \ |
| (c) * (layout == CblasRowMajor ? 1 : N)] |
| |
| int upperinc = (&triv(0, 1) - &triv(0, 0)); |
| int lowerinc = (&triv(1, 0) - &triv(0, 0)); |
| if (layout == CblasColMajor) { |
| assert(upperinc == N); |
| assert(lowerinc == 1); |
| } else { |
| assert(upperinc == 1); |
| assert(lowerinc == N); |
| } |
| bool is_lower = uplo == 'L' || uplo == 'l'; |
| for (int i = 0; i < N - 1; i++) { |
| cblas_daxpy(N - i - 1, 1.0, |
| is_lower ? &triv(i, i + 1) : &triv(i + 1, i), |
| is_lower ? upperinc : lowerinc, |
| is_lower ? &triv(i + 1, i) : &triv(i, i + 1), |
| is_lower ? lowerinc : upperinc); |
| } |
| |
| cblas_dlacpy(layout, uplo, N, N, tri, N, dA, lda); |
| |
| checkTest(Test); |
| |
| SkipVecIncCheck = true; |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| SkipVecIncCheck = false; |
| } |
| } |
| } |
| } |
| |
| static void potrsTests() { |
| int N = 17; |
| int Nrhs = M; |
| // N means normal matrix, T means transposed |
| for (char layout : {CblasColMajor, CblasRowMajor}) { |
| for (auto uplo : {'U', 'u', 'L', 'l'}) |
| |
| { |
| BlasInfo inputs[6] = { |
| /*A*/ BlasInfo(A, layout, N, N, lda), |
| /*B*/ BlasInfo(B, layout, N, Nrhs, incB), |
| /*C*/ BlasInfo(), |
| BlasInfo(), |
| BlasInfo(), |
| BlasInfo(), |
| }; |
| { |
| |
| std::string Test = "POTRS active A, B"; |
| init(); |
| |
| my_potrs(layout, uplo, N, Nrhs, A, lda, B, incB); |
| |
| assert(calls.size() == 1); |
| assert(calls[0].inDerivative == false); |
| assert(calls[0].type == CallType::POTRS); |
| assert(calls[0].pout_arg1 == B); |
| assert(calls[0].pin_arg1 == A); |
| assert(calls[0].pin_arg2 == UNUSED_POINTER); |
| assert(calls[0].farg1 == UNUSED_DOUBLE); |
| assert(calls[0].farg2 == UNUSED_DOUBLE); |
| assert(calls[0].layout == layout); |
| assert(calls[0].targ1 == UNUSED_TRANS); |
| assert(calls[0].targ2 == UNUSED_TRANS); |
| assert(calls[0].iarg1 == N); |
| assert(calls[0].iarg2 == Nrhs); |
| assert(calls[0].iarg3 == UNUSED_INT); |
| assert(calls[0].iarg4 == lda); |
| assert(calls[0].iarg5 == incB); |
| assert(calls[0].iarg6 == UNUSED_INT); |
| assert(calls[0].side == UNUSED_TRANS); |
| assert(calls[0].uplo == uplo); |
| assert(calls[0].diag == UNUSED_TRANS); |
| |
| // Check memory of primal on own. |
| checkMemoryTrace(inputs, "Primal " + Test, calls); |
| |
| init(); |
| __enzyme_autodiff((void *)my_potrs, enzyme_const, layout, enzyme_const, |
| uplo, enzyme_const, N, enzyme_const, Nrhs, enzyme_dup, A, dA, |
| enzyme_const, lda, enzyme_dup, B, dB, enzyme_const, incB); |
| foundCalls = calls; |
| init(); |
| |
| assert(foundCalls[0].type == CallType::LACPY); |
| double *inpB = (double *)foundCalls[0].pout_arg1; |
| inputs[3] = BlasInfo(inpB, layout, N, Nrhs, N); |
| cblas_dlacpy(layout, '\0', N, Nrhs, B, incB, inpB, N); |
| |
| my_potrs(layout, uplo, N, Nrhs, A, lda, B, incB); |
| |
| inDerivative = true; |
| |
| assert(foundCalls[2].type == CallType::SYR2K); |
| double *tri = (double *)foundCalls[2].pout_arg1; |
| inputs[4] = BlasInfo(tri, layout, N, N, N); |
| cblas_dsyr2k(layout, 'U', 'N', N, Nrhs, 1.0, inpB, N, dB, incB, 0.0, |
| tri, N); |
| |
| #define triv(r, c) \ |
| tri[(r) * (layout == CblasRowMajor ? N : 1) + \ |
| (c) * (layout == CblasRowMajor ? 1 : N)] |
| |
| bool is_lower = uplo == 'L' || uplo == 'l'; |
| int upperinc = (&triv(0, 1) - &triv(0,0)); |
| int lowerinc = (&triv(1, 0) - &triv(0,0)); |
| if (layout == CblasColMajor) { |
| assert(upperinc == N); |
| assert(lowerinc == 1); |
| } else { |
| assert(upperinc == 1); |
| assert(lowerinc == N); |
| } |
| for (int i = 0; i < N - 1; i++) { |
| cblas_dcopy(N - i - 1, &triv(i, i + 1), upperinc, &triv(i + 1, i), |
| lowerinc); |
| } |
| |
| cblas_dtrsm(layout, uplo_to_rside(uplo), uplo, 'T', 'N', N, N, 1.0, A, |
| lda, tri, N); |
| |
| cblas_dtrsm(layout, uplo_to_side(uplo), uplo, 'N', 'N', N, N, 1.0, A, |
| lda, tri, N); |
| |
| cblas_dtrsm(layout, uplo_to_side(uplo), uplo, 'T', 'N', N, N, 1.0, A, |
| lda, tri, N); |
| |
| |
| #define Av(r, c) \ |
| dA[(r) * (layout == CblasRowMajor ? lda : 1) + \ |
| (c) * (layout == CblasRowMajor ? 1 : lda)] |
| |
| int Aupperinc = (&Av(0, 1) - &Av(0,0)); |
| int Alowerinc = (&Av(1, 0) - &Av(0,0)); |
| if (layout == CblasColMajor) { |
| assert(Aupperinc == lda); |
| assert(Alowerinc == 1); |
| } else { |
| assert(Aupperinc == 1); |
| assert(Alowerinc == lda); |
| } |
| |
| for (int i = 0; i < N; i++) { |
| cblas_daxpy(N - i, -1.0, &triv(i, i), is_lower ? lowerinc : upperinc, |
| &Av(i, i), is_lower ? Alowerinc : Aupperinc); |
| } |
| |
| cblas_dpotrs(layout, uplo, N, Nrhs, A, lda, dB, incB, nullptr); |
| |
| checkTest(Test); |
| |
| SkipVecIncCheck = true; |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| SkipVecIncCheck = false; |
| } |
| { |
| |
| std::string Test = "POTRS active B"; |
| |
| init(); |
| __enzyme_autodiff((void *)my_potrs, enzyme_const, layout, enzyme_const, |
| uplo, enzyme_const, N, enzyme_const, Nrhs, enzyme_const, A, |
| enzyme_const, lda, enzyme_dup, B, dB, enzyme_const, incB); |
| foundCalls = calls; |
| init(); |
| |
| my_potrs(layout, uplo, N, Nrhs, A, lda, B, incB); |
| |
| inDerivative = true; |
| |
| |
| cblas_dpotrs(layout, uplo, N, Nrhs, A, lda, dB, incB, nullptr); |
| |
| checkTest(Test); |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| } |
| } |
| } |
| } |
| |
| static void trtrsTests() { |
| int N = 17; |
| int Nrhs = M; |
| // N means normal matrix, T means transposed |
| for (char layout : {CblasColMajor, CblasRowMajor}) { |
| for (auto uplo : {'U', 'u', 'L', 'l'}) |
| for (auto diag : {'U', 'u', 'N', 'n'}) |
| for (auto transA : |
| {CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans}) { |
| BlasInfo inputs[6] = { |
| /*A*/ BlasInfo(A, layout, N, N, lda), |
| /*B*/ BlasInfo(B, layout, N, Nrhs, incB), |
| /*C*/ BlasInfo(), |
| BlasInfo(), |
| BlasInfo(), |
| BlasInfo(), |
| }; |
| { |
| |
| std::string Test = "TRTRS active A, B"; |
| init(); |
| |
| my_trtrs(layout, uplo, (char)transA, diag, N, Nrhs, A, lda, B, |
| incB); |
| |
| assert(calls.size() == 1); |
| assert(calls[0].inDerivative == false); |
| assert(calls[0].type == CallType::TRTRS); |
| assert(calls[0].pout_arg1 == B); |
| assert(calls[0].pin_arg1 == A); |
| assert(calls[0].pin_arg2 == UNUSED_POINTER); |
| assert(calls[0].farg1 == UNUSED_DOUBLE); |
| assert(calls[0].farg2 == UNUSED_DOUBLE); |
| assert(calls[0].layout == layout); |
| assert(calls[0].targ1 == (char)transA); |
| assert(calls[0].targ2 == UNUSED_TRANS); |
| assert(calls[0].iarg1 == N); |
| assert(calls[0].iarg2 == Nrhs); |
| assert(calls[0].iarg3 == UNUSED_INT); |
| assert(calls[0].iarg4 == lda); |
| assert(calls[0].iarg5 == incB); |
| assert(calls[0].iarg6 == UNUSED_INT); |
| assert(calls[0].side == UNUSED_TRANS); |
| assert(calls[0].uplo == uplo); |
| assert(calls[0].diag == diag); |
| |
| // Check memory of primal on own. |
| checkMemoryTrace(inputs, "Primal " + Test, calls); |
| |
| init(); |
| __enzyme_autodiff((void *)my_trtrs, enzyme_const, layout, |
| enzyme_const, uplo, enzyme_const, (char)transA, |
| enzyme_const, diag, enzyme_const, N, enzyme_const, |
| Nrhs, enzyme_dup, A, dA, enzyme_const, lda, |
| enzyme_dup, B, dB, enzyme_const, incB); |
| foundCalls = calls; |
| init(); |
| |
| my_trtrs(layout, uplo, (char)transA, diag, N, Nrhs, A, lda, B, |
| incB); |
| |
| inDerivative = true; |
| |
| cblas_dtrtrs(layout, uplo, (char)transpose(transA), diag, N, Nrhs, |
| A, lda, dB, incB, nullptr); |
| |
| assert(foundCalls[2].type == CallType::LACPY); |
| double *tri = (double *)foundCalls[2].pout_arg1; |
| inputs[3] = BlasInfo(tri, layout, N, N, N); |
| |
| cblas_dlacpy(layout, uplo, N, N, dA, lda, tri, N); |
| |
| cblas_dgemm( |
| layout, 'N', 'T', N, N, Nrhs, -1.0, is_normal(transA) ? dB : B, |
| is_normal(transA) ? incB : incB, is_normal(transA) ? B : dB, |
| is_normal(transA) ? incB : incB, 1.0, tri, N); |
| |
| cblas_dcopy((diag == 'U' || diag == 'u') ? N : 0, dA, lda + 1, tri, |
| N + 1); |
| |
| cblas_dlacpy(layout, uplo, N, N, tri, N, dA, lda); |
| |
| checkTest(Test); |
| |
| SkipVecIncCheck = true; |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| SkipVecIncCheck = false; |
| } |
| { |
| |
| std::string Test = "TRTRS active B"; |
| |
| init(); |
| __enzyme_autodiff((void *)my_trtrs, enzyme_const, layout, |
| enzyme_const, uplo, enzyme_const, (char)transA, |
| enzyme_const, diag, enzyme_const, N, enzyme_const, |
| Nrhs, enzyme_const, A, enzyme_const, lda, |
| enzyme_dup, B, dB, enzyme_const, incB); |
| foundCalls = calls; |
| init(); |
| |
| my_trtrs(layout, uplo, (char)transA, diag, N, Nrhs, A, lda, B, |
| incB); |
| |
| inDerivative = true; |
| |
| cblas_dtrtrs(layout, uplo, (char)transpose(transA), diag, N, Nrhs, |
| A, lda, dB, incB, nullptr); |
| |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| } |
| { |
| |
| std::string Test = "TRTRS active A"; |
| |
| init(); |
| __enzyme_autodiff((void *)my_trtrs, enzyme_const, layout, |
| enzyme_const, uplo, enzyme_const, (char)transA, |
| enzyme_const, diag, enzyme_const, N, enzyme_const, |
| Nrhs, enzyme_dup, A, dA, enzyme_const, lda, |
| enzyme_const, B, enzyme_const, incB); |
| foundCalls = calls; |
| init(); |
| |
| my_trtrs(layout, uplo, (char)transA, diag, N, Nrhs, A, lda, B, |
| incB); |
| |
| inDerivative = true; |
| |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| } |
| { |
| |
| std::string Test = "TRTRS OW active A, B"; |
| |
| init(); |
| __enzyme_autodiff((void *)ow_trtrs, enzyme_const, layout, |
| enzyme_const, uplo, enzyme_const, (char)transA, |
| enzyme_const, diag, enzyme_const, N, enzyme_const, |
| Nrhs, enzyme_dup, A, dA, enzyme_const, lda, |
| enzyme_dup, B, dB, enzyme_const, incB); |
| foundCalls = calls; |
| init(); |
| |
| cblas_dtrtrs(layout, uplo, (char)transA, diag, N, Nrhs, A, lda, B, |
| incB, nullptr); |
| assert(foundCalls[1].type == CallType::LACPY); |
| double *cacheA = (double *)foundCalls[1].pout_arg1; |
| inputs[4] = BlasInfo(cacheA, (char)layout, N, N, N); |
| assert(inputs[4].ty == ValueType::Matrix); |
| cblas_dlacpy(layout, uplo, N, N, A, lda, cacheA, N); |
| |
| assert(foundCalls[2].type == CallType::LACPY); |
| double *cacheB = (double *)foundCalls[2].pout_arg1; |
| inputs[5] = BlasInfo(cacheB, (char)layout, N, Nrhs, N); |
| assert(inputs[5].ty == ValueType::Matrix); |
| cblas_dlacpy(layout, '\0', N, Nrhs, B, incB, cacheB, N); |
| cblas_dscal(1, 0.0, A, lda); |
| |
| inDerivative = true; |
| |
| cblas_dscal(1, 0.0, dA, lda); |
| |
| cblas_dtrtrs(layout, uplo, (char)transpose(transA), diag, N, Nrhs, |
| cacheA, N, dB, incB, nullptr); |
| |
| assert(foundCalls[6].type == CallType::LACPY); |
| double *tri = (double *)foundCalls[6].pout_arg1; |
| inputs[3] = BlasInfo(tri, layout, N, N, N); |
| |
| cblas_dlacpy(layout, uplo, N, N, dA, lda, tri, N); |
| |
| cblas_dgemm(layout, 'N', 'T', N, N, Nrhs, -1.0, |
| is_normal(transA) ? dB : cacheB, |
| is_normal(transA) ? incB : N, |
| is_normal(transA) ? cacheB : dB, |
| is_normal(transA) ? N : incB, 1.0, tri, N); |
| |
| cblas_dcopy((diag == 'U' || diag == 'u') ? N : 0, dA, lda + 1, tri, |
| N + 1); |
| |
| cblas_dlacpy(layout, uplo, N, N, tri, N, dA, lda); |
| |
| checkTest(Test); |
| |
| SkipVecIncCheck = true; |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| SkipVecIncCheck = false; |
| } |
| } |
| } |
| } |
| |
| static void symmTests() { |
| int N = 17; |
| int M = 9; |
| // N means normal matrix, T means transposed |
| for (char layout : {CblasColMajor, CblasRowMajor}) { |
| for (auto uplo : {'U', 'u', 'L', 'l'}) |
| for (auto side : {'L', 'l', 'R', 'r'}) { |
| BlasInfo inputs[6] = { |
| /*A*/ BlasInfo(A, layout, is_left(side) ? M : N, is_left(side) ? M : N, lda), |
| /*B*/ BlasInfo(B, layout, M, N, incB), |
| /*C*/ BlasInfo(C, layout, M, N, incC), |
| BlasInfo(), |
| BlasInfo(), |
| BlasInfo(), |
| }; |
| { |
| |
| std::string Test = "SYMM active A, B, C"; |
| init(); |
| |
| my_symm(layout, side, uplo, M, N, alpha, A, lda, B, incB, beta, C, incC); |
| |
| assert(calls.size() == 1); |
| assert(calls[0].inDerivative == false); |
| assert(calls[0].type == CallType::SYMM); |
| assert(calls[0].pout_arg1 == C); |
| assert(calls[0].pin_arg1 == A); |
| assert(calls[0].pin_arg2 == B); |
| assert(calls[0].farg1 == alpha); |
| assert(calls[0].farg2 == beta); |
| assert(calls[0].layout == layout); |
| assert(calls[0].targ1 == UNUSED_TRANS); |
| assert(calls[0].targ2 == UNUSED_TRANS); |
| assert(calls[0].iarg1 == M); |
| assert(calls[0].iarg2 == N); |
| assert(calls[0].iarg3 == UNUSED_INT); |
| assert(calls[0].iarg4 == lda); |
| assert(calls[0].iarg5 == incB); |
| assert(calls[0].iarg6 == incC); |
| assert(calls[0].side == side); |
| assert(calls[0].uplo == uplo); |
| assert(calls[0].diag == UNUSED_TRANS); |
| |
| // Check memory of primal on own. |
| checkMemoryTrace(inputs, "Primal " + Test, calls); |
| |
| init(); |
| __enzyme_autodiff((void *)my_symm, |
| enzyme_const, layout, |
| enzyme_const, side, |
| enzyme_const, uplo, |
| enzyme_const, M, |
| enzyme_const, N, |
| enzyme_const, alpha, |
| enzyme_dup, A, dA, |
| enzyme_const, lda, |
| enzyme_dup, B, dB, |
| enzyme_const, incB, |
| enzyme_const, beta, |
| enzyme_dup, C, dC, |
| enzyme_const, incC); |
| foundCalls = calls; |
| init(); |
| |
| my_symm(layout, side, uplo, M, N, alpha, A, lda, B, incB, beta, C, incC); |
| |
| inDerivative = true; |
| |
| |
| assert(foundCalls[1].type == CallType::COPY); |
| double *tmp = (double *)foundCalls[1].pout_arg1; |
| cblas_dcopy(is_left(side) ? M : N, dA, lda+1, tmp, 1); |
| inputs[3] = BlasInfo(tmp, is_left(side) ? M : N, 1); |
| |
| // ssyr2k(uplo, 'n', m, n, alpha,B,ldb,Ca,ldc, 1.0,Aa,lda) |
| // ssyr2k(uplo,'t', n,m, alpha,B,ldb,Ca,ldc, 1.0,Aa,lda) |
| cblas_dsyr2k(layout, |
| uplo, |
| side_to_trans(side), |
| is_left(side) ? M : N, |
| is_left(side) ? N : M, |
| alpha, |
| B, |
| incB, |
| dC, |
| incC, |
| 1.0, |
| dA, |
| lda); |
| |
| cblas_daxpy(is_left(side) ? M : N, -1, dA, lda+1, tmp, 1); |
| cblas_daxpy(is_left(side) ? M : N, 0.5, tmp, 1, dA, lda+1); |
| |
| cblas_dsymm(layout, side, uplo, M, N, alpha, A, lda, dC, incC, 1.0, dB, incB); |
| |
| cblas_dlascl(layout, 'G', 0, 0, 1.0, beta, M, N, dC, incC, 0 ); |
| |
| checkTest(Test); |
| |
| SkipVecIncCheck = true; |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| SkipVecIncCheck = false; |
| } |
| { |
| |
| std::string Test = "SYMM overwriten active A, B, C"; |
| init(); |
| |
| ow_symm(layout, side, uplo, M, N, alpha, A, lda, B, incB, beta, C, incC); |
| |
| // Check memory of primal on own. |
| checkMemoryTrace(inputs, "Primal " + Test, calls); |
| |
| init(); |
| __enzyme_autodiff((void *)ow_symm, |
| enzyme_const, layout, |
| enzyme_const, side, |
| enzyme_const, uplo, |
| enzyme_const, M, |
| enzyme_const, N, |
| enzyme_const, alpha, |
| enzyme_dup, A, dA, |
| enzyme_const, lda, |
| enzyme_dup, B, dB, |
| enzyme_const, incB, |
| enzyme_const, beta, |
| enzyme_dup, C, dC, |
| enzyme_const, incC); |
| foundCalls = calls; |
| init(); |
| |
| double *cacheA = (double *)foundCalls[0].pout_arg1; |
| inputs[4] = BlasInfo(cacheA, layout, is_left(side) ? M : N, is_left(side) ? M : N, is_left(side) ? M : N); |
| assert(inputs[4].ty == ValueType::Matrix); |
| cblas_dlacpy(layout, '\0', is_left(side) ? M : N, is_left(side) ? M : N, A, lda, cacheA, is_left(side) ? M : N); |
| |
| double *cacheB = (double *)foundCalls[1].pout_arg1; |
| inputs[5] = BlasInfo(cacheB, layout, M, N, M); |
| assert(inputs[5].ty == ValueType::Matrix); |
| cblas_dlacpy(layout, '\0', M, N, B, incB, cacheB, M); |
| |
| ow_symm(layout, side, uplo, M, N, alpha, A, lda, B, incB, beta, C, incC); |
| |
| inDerivative = true; |
| |
| //cblas_dscal(1, 0.0, dA, lda); |
| |
| |
| //assert(foundCalls[1].type == CallType::COPY); |
| double *tmp = (double *)foundCalls[3].pout_arg1; |
| cblas_dcopy(is_left(side) ? M : N, dA, lda+1, tmp, 1); |
| inputs[3] = BlasInfo(tmp, is_left(side) ? M : N, 1); |
| |
| // ssyr2k(uplo, 'n', m, n, alpha,B,ldb,Ca,ldc, 1.0,Aa,lda) |
| // ssyr2k(uplo,'t', n,m, alpha,B,ldb,Ca,ldc, 1.0,Aa,lda) |
| cblas_dsyr2k(layout, |
| uplo, |
| side_to_trans(side), |
| is_left(side) ? M : N, |
| is_left(side) ? N : M, |
| alpha, |
| cacheB, |
| M, |
| dC, |
| incC, |
| 1.0, |
| dA, |
| lda); |
| |
| cblas_daxpy(is_left(side) ? M : N, -1, dA, lda+1, tmp, 1); |
| cblas_daxpy(is_left(side) ? M : N, 0.5, tmp, 1, dA, lda+1); |
| |
| cblas_dsymm(layout, side, uplo, M, N, alpha, cacheA, is_left(side) ? M : N, dC, incC, 1.0, dB, incB); |
| |
| cblas_dlascl(layout, 'G', 0, 0, 1.0, beta, M, N, dC, incC, 0 ); |
| |
| checkTest(Test); |
| |
| SkipVecIncCheck = true; |
| // Check memory of primal of expected derivative |
| checkMemoryTrace(inputs, "Expected " + Test, calls); |
| |
| // Check memory of primal of our derivative (if equal above, it |
| // should be the same). |
| checkMemoryTrace(inputs, "Found " + Test, foundCalls); |
| SkipVecIncCheck = false; |
| } |
| } |
| } |
| } |
| |
| int main() { |
| /* |
| dotTests(); |
| |
| nrm2Tests(); |
| |
| gemvTests(); |
| |
| gemmTests(); |
| |
| trmvTests(); |
| |
| trmmTests(); |
| |
| syrkTests(); |
| |
| potrfTests(); |
| |
| potrsTests(); |
| |
| trtrsTests(); |
| |
| symmTests(); |
| */ |
| |
| symvTests(); |
| } |