// blob: eaab5a8b1e62f1eeb4fc7ca5711dd0d96d529c91 [file] [log] [blame] [edit]
// This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load
// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions...
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli -; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli -; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli -; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli -; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli -; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli -; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
#include "../blasinfra.h"
#include "../test_utils.h"
// sqrt is declared by hand here; nrm2Tests uses it when building its reference
// value.
extern "C" double sqrt(double);
// Marker globals passed to __enzyme_fwddiff in front of each argument of the
// differentiated function. In the calls below, enzyme_dup is followed by a
// value and its tangent (e.g. `A, dA`), while enzyme_const is followed by a
// single value. None of them is assigned anywhere in this file.
// NOTE(review): presumably Enzyme recognizes these globals by name rather than
// by value — confirm against the Enzyme documentation.
int enzyme_dup;
// Not used by any call visible in this file.
int enzyme_out;
int enzyme_const;
// Declaration only; no definition appears in this file. The first argument is
// the function to differentiate (cast to void *), followed by marker/argument
// groups as described above. RT is the type the caller expects back.
// NOTE(review): presumably the Enzyme plugin loaded by the RUN lines replaces
// each call with the forward-mode derivative — confirm.
template <typename RT, typename... T> RT __enzyme_fwddiff(void *, T...);
// Forwards every argument unchanged to cblas_dgemv. This wrapper (with
// __restrict__ pointer parameters) is the function handed to __enzyme_fwddiff
// in gemvTests.
void my_dgemv(char layout, char trans, int M, int N, double alpha,
double *__restrict__ A, int lda, double *__restrict__ X, int incx,
double beta, double *__restrict__ Y, int incy) {
cblas_dgemv(layout, trans, M, N, alpha, A, lda, X, incx, beta, Y, incy);
}
// Forwards every argument unchanged to cblas_dsymv. This wrapper (with
// __restrict__ pointer parameters) is the function handed to __enzyme_fwddiff
// in symvTests.
void my_dsymv(char layout, char uplo, int N, double alpha,
double *__restrict__ A, int lda, double *__restrict__ X, int incx,
double beta, double *__restrict__ Y, int incy) {
cblas_dsymv(layout, uplo, N, alpha, A, lda, X, incx, beta, Y, incy);
}
// Wrapper differentiated by dotTests: passes its arguments straight through to
// cblas_ddot and hands back whatever that call returns.
double my_ddot(int N, double *__restrict__ X, int incx, double *__restrict__ Y,
               int incy) {
  return cblas_ddot(N, X, incx, Y, incy);
}
// Wrapper differentiated by nrm2Tests: passes its arguments straight through
// to cblas_dnrm2 and hands back whatever that call returns.
double my_dnrm2(int N, double *__restrict__ X, int incx) {
  return cblas_dnrm2(N, X, incx);
}
// Forwards every argument unchanged to cblas_dgemm. This wrapper (with
// __restrict__ pointer parameters) is the function handed to __enzyme_fwddiff
// in gemmTests, and is also used there to replay the expected calls.
void my_dgemm(char layout, char transA, char transB, int M, int N, int K,
double alpha, double *__restrict__ A, int lda,
double *__restrict__ B, int ldb, double beta,
double *__restrict__ C, int ldc) {
cblas_dgemm(layout, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C,
ldc);
}
// Same forwarding to cblas_dgemm as my_dgemm, but the pointer parameters are
// not marked __restrict__. Not referenced by any test visible in this file.
void ow_dgemm(char layout, char transA, char transB, int M, int N, int K,
double alpha, double *A, int lda, double *B, int ldb, double beta,
double *C, int ldc) {
cblas_dgemm(layout, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C,
ldc);
}
// Forwards every argument unchanged to cblas_dsyrk. This wrapper (with
// __restrict__ pointer parameters) is the function handed to __enzyme_fwddiff
// in syrkTests.
void my_dsyrk(char layout, char uplo, char trans,
int N, int K, double alpha,
double *__restrict__ A, int lda, double beta,
double *__restrict__ C, int ldc) {
cblas_dsyrk(layout, uplo, trans, N, K, alpha, A, lda, beta,
C, ldc);
}
// Wrapper differentiated by potrfTests: forwards layout, uplo, N, A and lda to
// cblas_dpotrf.
//
// The trailing info argument is passed as nullptr, so no status is reported
// back to the caller.
// NOTE(review): presumably the cblas_dpotrf shim from the included headers
// tolerates a null info pointer — confirm before pointing this at a real
// LAPACK implementation.
void my_potrf(char layout, char uplo, int N, double *__restrict__ A, int lda) {
  cblas_dpotrf(layout, uplo, N, A, lda, nullptr);
}
// Forward-mode tests for my_ddot. Every scope follows the same recipe:
//   1. init(), run the primal on its own and check its memory trace;
//   2. init(), run __enzyme_fwddiff and snapshot the recorded calls into
//      foundCalls;
//   3. init(), issue by hand the calls the derivative is expected to make,
//      then checkTest() and check the memory trace of both call lists.
// NOTE(review): init, checkTest, checkMemoryTrace, calls and foundCalls come
// from the included headers; presumably checkTest compares `calls` against
// `foundCalls` — confirm there.
static void dotTests() {
// Case 1: both A and B carry a tangent (dA, dB).
{
std::string Test = "DOT active both ";
BlasInfo inputs[6] = {
/*A*/ BlasInfo(A, N, incA),
/*B*/ BlasInfo(B, N, incB),
/*C*/ BlasInfo(C, M, incC), BlasInfo(), BlasInfo(), BlasInfo(),
};
init();
my_ddot(N, A, incA, B, incB);
// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);
init();
__enzyme_fwddiff<double>((void *)my_ddot, enzyme_const, N, enzyme_dup, A,
dA, enzyme_const, incA, enzyme_dup, B, dB,
enzyme_const, incB);
foundCalls = calls;
init();
// Expected calls: dot(dA, B), dot(A, dB), then the primal dot.
cblas_ddot(N, dA, incA, B, incB);
cblas_ddot(N, A, incA, dB, incB);
my_ddot(N, A, incA, B, incB);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}
// Case 2: only A carries a tangent; B is passed as enzyme_const.
{
std::string Test = "DOT active A ";
BlasInfo inputs[6] = {
/*A*/ BlasInfo(A, N, incA),
/*B*/ BlasInfo(B, N, incB),
/*C*/ BlasInfo(C, M, incC), BlasInfo(), BlasInfo(), BlasInfo(),
};
init();
my_ddot(N, A, incA, B, incB);
// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);
init();
__enzyme_fwddiff<double>((void *)my_ddot, enzyme_const, N, enzyme_dup, A,
dA, enzyme_const, incA, enzyme_const, B,
enzyme_const, incB);
foundCalls = calls;
init();
// Expected calls: dot(dA, B), then the primal dot.
cblas_ddot(N, dA, incA, B, incB);
my_ddot(N, A, incA, B, incB);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}
// Case 3: only B carries a tangent; A is passed as enzyme_const.
{
std::string Test = "DOT active B ";
BlasInfo inputs[6] = {
/*A*/ BlasInfo(A, N, incA),
/*B*/ BlasInfo(B, N, incB),
/*C*/ BlasInfo(C, M, incC), BlasInfo(), BlasInfo(), BlasInfo(),
};
init();
my_ddot(N, A, incA, B, incB);
// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);
init();
__enzyme_fwddiff<double>((void *)my_ddot, enzyme_const, N, enzyme_const, A,
enzyme_const, incA, enzyme_dup, B, dB,
enzyme_const, incB);
foundCalls = calls;
init();
// Expected calls: dot(A, dB), then the primal dot.
cblas_ddot(N, A, incA, dB, incB);
my_ddot(N, A, incA, B, incB);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}
// Case 4: every argument is enzyme_const, so only the primal call is expected
// and the returned tangent must be (approximately) zero.
{
std::string Test = "DOT const both";
BlasInfo inputs[6] = {
/*A*/ BlasInfo(A, N, incA),
/*B*/ BlasInfo(B, N, incB),
/*C*/ BlasInfo(C, M, incC), BlasInfo(), BlasInfo(), BlasInfo(),
};
init();
my_ddot(N, A, incA, B, incB);
// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);
init();
double dres = __enzyme_fwddiff<double>((void *)my_ddot, enzyme_const, N,
enzyme_const, A, enzyme_const, incA,
enzyme_const, B, enzyme_const, incB);
foundCalls = calls;
init();
my_ddot(N, A, incA, B, incB);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
APPROX_EQ(dres, 0.0, 1e-10);
}
}
// Forward-mode tests for my_dnrm2, using the same primal / derivative /
// hand-written-expectation recipe as dotTests, plus a check on the returned
// tangent value.
// NOTE(review): init, checkTest, checkMemoryTrace, calls and foundCalls come
// from the included headers; their exact behavior is not visible here.
static void nrm2Tests() {
// Case 1: A carries a tangent dA.
{
std::string Test = "NRM2 active";
BlasInfo inputs[6] = {
/*A*/ BlasInfo(A, N, incA),
/*B*/ BlasInfo(B, N, incB),
/*C*/ BlasInfo(C, M, incC), BlasInfo(), BlasInfo(), BlasInfo(),
};
init();
my_dnrm2(N, A, incA);
// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);
init();
double ADres =
__enzyme_fwddiff<double>((void *)my_dnrm2, enzyme_const, N, enzyme_dup,
A, dA, enzyme_const, incA);
foundCalls = calls;
init();
// Expected calls: dot(A, dA), nrm2(A), and a second nrm2(A) whose result is
// discarded. The first two results are combined into the reference tangent.
// NOTE(review): analytically the tangent of ||A|| is dot(A, dA) / ||A||; the
// sqrt applied to the dot result here should be checked against the values
// the cblas_* shims actually return.
double dres = cblas_ddot(N, A, incA, dA, incA);
double nres = my_dnrm2(N, A, incA);
double trueRes = sqrt(dres) / nres;
my_dnrm2(N, A, incA);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
APPROX_EQ(ADres, trueRes, 1e-10);
}
// Case 2: every argument is enzyme_const; the returned tangent must be
// (approximately) zero.
{
std::string Test = "NRM2 const";
BlasInfo inputs[6] = {
/*A*/ BlasInfo(A, N, incA),
/*B*/ BlasInfo(B, N, incB),
/*C*/ BlasInfo(C, M, incC), BlasInfo(), BlasInfo(), BlasInfo(),
};
init();
my_dnrm2(N, A, incA);
// NOTE(review): this my_ddot call is not the function under test in this
// case; it only adds an extra entry to the primal trace checked below.
my_ddot(N, A, incA, B, incB);
// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);
init();
double dres = __enzyme_fwddiff<double>((void *)my_dnrm2, enzyme_const, N,
enzyme_const, A, enzyme_const, incA);
foundCalls = calls;
init();
my_dnrm2(N, A, incA);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
APPROX_EQ(dres, 0.0, 1e-10);
}
}
// Forward-mode tests for my_dgemv over both layouts and both transpose
// settings. For each combination the primal call record is asserted field by
// field, then three activity patterns are differentiated (A+C, A+B+C, B+C)
// and compared against hand-written expected call sequences.
// NOTE(review): init, checkTest, checkMemoryTrace, calls and foundCalls come
// from the included headers; their exact behavior is not visible here.
static void gemvTests() {
// N means normal matrix, T means transposed
for (char layout : {CblasRowMajor, CblasColMajor}) {
for (auto transA :
{CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans}) {
// todo in fortran blas consider 'N', 'n', 'T', 't'}
{
// With a transposed A the described lengths of B and C swap (M <-> N).
bool trans = !is_normal(transA);
std::string Test = "GEMV active A, C ";
BlasInfo inputs[6] = {/*A*/ BlasInfo(A, layout, M, N, lda),
/*B*/ BlasInfo(B, trans ? M : N, incB),
/*C*/ BlasInfo(C, trans ? N : M, incC),
BlasInfo(),
BlasInfo(),
BlasInfo()};
init();
my_dgemv(layout, (char)transA, M, N, alpha, A, lda, B, incB, beta, C,
incC);
// The single primal call must have been recorded with exactly the arguments
// passed above.
assert(calls.size() == 1);
assert(calls[0].inDerivative == false);
assert(calls[0].type == CallType::GEMV);
assert(calls[0].pout_arg1 == C);
assert(calls[0].pin_arg1 == A);
assert(calls[0].pin_arg2 == B);
assert(calls[0].farg1 == alpha);
assert(calls[0].farg2 == beta);
assert(calls[0].layout == layout);
assert(calls[0].targ1 == (char)transA);
assert(calls[0].targ2 == UNUSED_TRANS);
assert(calls[0].iarg1 == M);
assert(calls[0].iarg2 == N);
assert(calls[0].iarg3 == UNUSED_INT);
assert(calls[0].iarg4 == lda);
assert(calls[0].iarg5 == incB);
assert(calls[0].iarg6 == incC);
// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);
init();
// A and C active, B constant.
__enzyme_fwddiff<void>(
(void *)my_dgemv, enzyme_const, layout, enzyme_const, transA,
enzyme_const, M, enzyme_const, N, enzyme_const, alpha, enzyme_dup,
A, dA, enzyme_const, lda, enzyme_const, B, enzyme_const, incB,
enzyme_const, beta, enzyme_dup, C, dC, enzyme_const, incC);
foundCalls = calls;
init();
// Expected calls: gemv(dA, B) into dC scaled by beta, then the primal.
cblas_dgemv(layout, (char)transA, M, N, alpha, dA, lda, B, incB, beta,
dC, incC);
my_dgemv(layout, (char)transA, M, N, alpha, A, lda, B, incB, beta, C,
incC);
// cblas_dscal(trans ? N : M, beta, dC, incC);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
// A, B and C all active.
Test = "GEMV active A, B, C ";
init();
__enzyme_fwddiff<void>(
(void *)my_dgemv, enzyme_const, layout, enzyme_const, transA,
enzyme_const, M, enzyme_const, N, enzyme_const, alpha, enzyme_dup,
A, dA, enzyme_const, lda, enzyme_dup, B, dB, enzyme_const, incB,
enzyme_const, beta, enzyme_dup, C, dC, enzyme_const, incC);
foundCalls = calls;
init();
// Expected calls: gemv(A, dB) into dC scaled by beta, then gemv(dA, B)
// accumulated into dC with a factor of 1.0, then the primal.
cblas_dgemv(layout, (char)transA, M, N, alpha, A, lda, dB, incB, beta,
dC, incC);
cblas_dgemv(layout, (char)transA, M, N, alpha, dA, lda, B, incB, 1.0, dC, incC);
// cblas_dscal(trans ? N : M, beta, dC, incC);
my_dgemv(layout, (char)transA, M, N, alpha, A, lda, B, incB, beta, C,
incC);
// NOT ACTIVE: cblas_dgemv(layout, trans, M, N, dalpha, A, lda, B,
// incB, 1.0, C, incC);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
// B and C active, A constant.
Test = "GEMV active B, C ";
init();
__enzyme_fwddiff<void>(
(void *)my_dgemv, enzyme_const, layout, enzyme_const, transA,
enzyme_const, M, enzyme_const, N, enzyme_const, alpha, enzyme_const,
A, enzyme_const, lda, enzyme_dup, B, dB, enzyme_const, incB,
enzyme_const, beta, enzyme_dup, C, dC, enzyme_const, incC);
foundCalls = calls;
init();
// Expected calls: gemv(A, dB) into dC scaled by beta, then the primal.
cblas_dgemv(layout, (char)transA, M, N, alpha, A, lda, dB, incB, beta,
dC, incC);
// cblas_dscal(trans ? N : M, beta, dC, incC);
my_dgemv(layout, (char)transA, M, N, alpha, A, lda, B, incB, beta, C,
incC);
// NOT ACTIVE: cblas_dgemv(layout, trans, M, N, dalpha, A, lda, B,
// incB, 1.0, C, incC);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}
}
}
}
// Forward-mode tests for my_dsymv over both layouts and all four spellings of
// uplo ('U', 'u', 'L', 'l'). For each combination the primal call record is
// asserted field by field, then three activity patterns are differentiated
// (A+C, A+B+C, B+C) and compared against hand-written expected call sequences.
// NOTE(review): init, checkTest, checkMemoryTrace, calls and foundCalls come
// from the included headers; their exact behavior is not visible here.
static void symvTests() {
// A is described as N x N; B and C as length-N vectors. There is no transpose
// argument for symv.
for (char layout : {CblasRowMajor, CblasColMajor}) {
for (auto uplo : {'U', 'u', 'L', 'l'}) {
{
std::string Test = "SYMV active A, C ";
BlasInfo inputs[6] = {/*A*/ BlasInfo(A, layout, N, N, lda),
/*B*/ BlasInfo(B, N, incB),
/*C*/ BlasInfo(C, N, incC),
BlasInfo(),
BlasInfo(),
BlasInfo()};
init();
my_dsymv(layout, uplo, N, alpha, A, lda, B, incB, beta, C,
incC);
// The single primal call must have been recorded with exactly the arguments
// passed above. iarg2 is not checked here.
assert(calls.size() == 1);
assert(calls[0].inDerivative == false);
assert(calls[0].type == CallType::SYMV);
assert(calls[0].pout_arg1 == C);
assert(calls[0].pin_arg1 == A);
assert(calls[0].pin_arg2 == B);
assert(calls[0].farg1 == alpha);
assert(calls[0].farg2 == beta);
assert(calls[0].layout == layout);
assert(calls[0].uplo == uplo);
assert(calls[0].targ1 == UNUSED_TRANS);
assert(calls[0].targ2 == UNUSED_TRANS);
assert(calls[0].iarg1 == N);
assert(calls[0].iarg3 == UNUSED_INT);
assert(calls[0].iarg4 == lda);
assert(calls[0].iarg5 == incB);
assert(calls[0].iarg6 == incC);
// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);
init();
// A and C active, B constant.
__enzyme_fwddiff<void>(
(void *)my_dsymv, enzyme_const, layout, enzyme_const, uplo,
enzyme_const, N, enzyme_const, alpha, enzyme_dup,
A, dA, enzyme_const, lda, enzyme_const, B, enzyme_const, incB,
enzyme_const, beta, enzyme_dup, C, dC, enzyme_const, incC);
foundCalls = calls;
init();
// Expected calls: symv(dA, B) into dC scaled by beta, then the primal.
cblas_dsymv(layout, uplo, N, alpha, dA, lda, B, incB, beta,
dC, incC);
my_dsymv(layout, uplo, N, alpha, A, lda, B, incB, beta, C,
incC);
// cblas_dscal(trans ? N : M, beta, dC, incC);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
// A, B and C all active.
Test = "SYMV active A, B, C ";
init();
__enzyme_fwddiff<void>(
(void *)my_dsymv, enzyme_const, layout, enzyme_const, uplo,
enzyme_const, N, enzyme_const, alpha, enzyme_dup,
A, dA, enzyme_const, lda, enzyme_dup, B, dB, enzyme_const, incB,
enzyme_const, beta, enzyme_dup, C, dC, enzyme_const, incC);
foundCalls = calls;
init();
// Expected calls: symv(A, dB) into dC scaled by beta, then symv(dA, B)
// accumulated into dC with a factor of 1.0, then the primal.
cblas_dsymv(layout, uplo, N, alpha, A, lda, dB, incB, beta,
dC, incC);
cblas_dsymv(layout, uplo, N, alpha, dA, lda, B, incB, 1.0, dC, incC);
// cblas_dscal(trans ? N : M, beta, dC, incC);
my_dsymv(layout, uplo, N, alpha, A, lda, B, incB, beta, C,
incC);
// NOT ACTIVE: cblas_dgemv(layout, trans, M, N, dalpha, A, lda, B,
// incB, 1.0, C, incC);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
// B and C active, A constant.
Test = "SYMV active B, C ";
init();
__enzyme_fwddiff<void>(
(void *)my_dsymv, enzyme_const, layout, enzyme_const, uplo,
enzyme_const, N, enzyme_const, alpha, enzyme_const,
A, enzyme_const, lda, enzyme_dup, B, dB, enzyme_const, incB,
enzyme_const, beta, enzyme_dup, C, dC, enzyme_const, incC);
foundCalls = calls;
init();
// Expected calls: symv(A, dB) into dC scaled by beta, then the primal.
cblas_dsymv(layout, uplo, N, alpha, A, lda, dB, incB, beta,
dC, incC);
// cblas_dscal(trans ? N : M, beta, dC, incC);
my_dsymv(layout, uplo, N, alpha, A, lda, B, incB, beta, C,
incC);
// NOT ACTIVE: cblas_dgemv(layout, trans, M, N, dalpha, A, lda, B,
// incB, 1.0, C, incC);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}
}
}
}
// Forward-mode tests for my_dgemm over both layouts and every combination of
// transA / transB. For each combination the primal call record is asserted
// field by field, then three activity patterns are differentiated (A+B+C,
// A+C, B+C) and compared against hand-written expected call sequences.
// Note that incB and incC are passed in the ldb / ldc positions throughout.
// NOTE(review): init, checkTest, checkMemoryTrace, calls and foundCalls come
// from the included headers; their exact behavior is not visible here.
static void gemmTests() {
// N means normal matrix, T means transposed
for (char layout : {CblasRowMajor, CblasColMajor}) {
for (auto transA :
{CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans}) {
for (auto transB :
{CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans}) {
// todo fortran blas {'N', 'n', 'T', 't'}
{
// A transposed operand is described with its two dimensions swapped.
bool transA_bool = !is_normal(transA);
bool transB_bool = !is_normal(transB);
std::string Test = "GEMM Active A, B, C";
BlasInfo inputs[6] = {/*A*/ BlasInfo(A, layout, transA_bool ? K : M,
transA_bool ? M : K, lda),
/*B*/
BlasInfo(B, layout, transB_bool ? N : K,
transB_bool ? K : N, incB),
/*C*/ BlasInfo(C, layout, M, N, incC),
BlasInfo(),
BlasInfo(),
BlasInfo()};
init();
my_dgemm(layout, (char)transA, (char)transB, M, N, K, alpha, A, lda,
B, incB, beta, C, incC);
// The single primal call must have been recorded with exactly the arguments
// passed above.
assert(calls.size() == 1);
assert(calls[0].inDerivative == false);
assert(calls[0].type == CallType::GEMM);
assert(calls[0].pout_arg1 == C);
assert(calls[0].pin_arg1 == A);
assert(calls[0].pin_arg2 == B);
assert(calls[0].farg1 == alpha);
assert(calls[0].farg2 == beta);
assert(calls[0].layout == layout);
assert(calls[0].targ1 == (char)transA);
assert(calls[0].targ2 == (char)transB);
assert(calls[0].iarg1 == M);
assert(calls[0].iarg2 == N);
assert(calls[0].iarg3 == K);
assert(calls[0].iarg4 == lda);
assert(calls[0].iarg5 == incB);
assert(calls[0].iarg6 == incC);
// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);
init();
// A, B and C all active.
__enzyme_fwddiff<void>(
(void *)my_dgemm, enzyme_const, layout, enzyme_const, transA,
enzyme_const, transB, enzyme_const, M, enzyme_const, N,
enzyme_const, K, enzyme_const, alpha, enzyme_dup, A, dA,
enzyme_const, lda, enzyme_dup, B, dB, enzyme_const, incB,
enzyme_const, beta, enzyme_dup, C, dC, enzyme_const, incC);
foundCalls = calls;
init();
// Expected calls: gemm(A, dB) into dC scaled by beta, then gemm(dA, B)
// accumulated into dC with a factor of 1.0, then the primal.
my_dgemm(layout, (char)transA, (char)transB, M, N, K, alpha, A, lda,
dB, incB, beta, dC, incC);
my_dgemm(layout, (char)transA, (char)transB, M, N, K, alpha, dA, lda,
B, incB, 1.0, dC, incC);
// NOT ACTIVE: my_dgemm(layout, (char)transA, (char)transB, M, N, K,
// dalpha, A, lda, B, incB, 1.0, C, incC);
// cblas_dlascl(layout, 'G', 0, 0, 1.0, beta, M, N, dC, incC, 0);
my_dgemm(layout, (char)transA, (char)transB, M, N, K, alpha, A, lda,
B, incB, beta, C, incC);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
// A and C active, B constant.
Test = "GEMM Active A, C";
init();
__enzyme_fwddiff<void>(
(void *)my_dgemm, enzyme_const, layout, enzyme_const, transA,
enzyme_const, transB, enzyme_const, M, enzyme_const, N,
enzyme_const, K, enzyme_const, alpha, enzyme_dup, A, dA,
enzyme_const, lda, enzyme_const, B, enzyme_const, incB,
enzyme_const, beta, enzyme_dup, C, dC, enzyme_const, incC);
foundCalls = calls;
init();
// Expected calls: gemm(dA, B) into dC scaled by beta, then the primal.
my_dgemm(layout, (char)transA, (char)transB, M, N, K, alpha, dA, lda,
B, incB, beta, dC, incC);
// NOT ACTIVE: my_dgemm(layout, (char)transA, (char)transB, M, N, K,
// dalpha, A, lda, B, incB, 1.0, C, incC);
// cblas_dlascl(layout, 'G', 0, 0, 1.0, beta, M, N, dC, incC, 0);
my_dgemm(layout, (char)transA, (char)transB, M, N, K, alpha, A, lda,
B, incB, beta, C, incC);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
// B and C active, A constant.
Test = "GEMM Active B, C";
init();
__enzyme_fwddiff<void>(
(void *)my_dgemm, enzyme_const, layout, enzyme_const, transA,
enzyme_const, transB, enzyme_const, M, enzyme_const, N,
enzyme_const, K, enzyme_const, alpha, enzyme_const, A,
enzyme_const, lda, enzyme_dup, B, dB, enzyme_const, incB,
enzyme_const, beta, enzyme_dup, C, dC, enzyme_const, incC);
foundCalls = calls;
init();
// Expected calls: gemm(A, dB) into dC scaled by beta, then the primal.
my_dgemm(layout, (char)transA, (char)transB, M, N, K, alpha, A, lda,
dB, incB, beta, dC, incC);
// NOT ACTIVE: my_dgemm(layout, (char)transA, (char)transB, M, N, K,
// dalpha, A, lda, B, incB, 1.0, C, incC);
// cblas_dlascl(layout, 'G', 0, 0, 1.0, beta, M, N, dC, incC, 0);
my_dgemm(layout, (char)transA, (char)transB, M, N, K, alpha, A, lda,
B, incB, beta, C, incC);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}
}
}
}
}
// Forward-mode tests for my_dsyrk over both layouts, both transpose settings
// and all four spellings of uplo. The `inputs` description is shared by the
// three scopes inside the innermost loop body; incC is passed in the ldc
// position throughout.
// NOTE(review): init, checkTest, checkMemoryTrace, calls and foundCalls come
// from the included headers; their exact behavior is not visible here.
static void syrkTests() {
// N means normal matrix, T means transposed
for (char layout : {CblasColMajor, CblasRowMajor}) {
// The next two loops have no braces of their own; the block that follows is
// the body of the innermost (uplo) loop.
for (auto transA :
{CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans})
for (auto uplo : {'U', 'u', 'L', 'l'})
{
// A is described as N x K, or K x N when transposed; C as N x N.
bool trans = !is_normal(transA);
BlasInfo inputs[6] = {
/*A*/ BlasInfo(A, layout, trans ? K : N, trans ? N : K, lda),
/*B*/ BlasInfo(),
/*C*/ BlasInfo(C, layout, N, N, incC),
BlasInfo(),
BlasInfo(),
BlasInfo(),
};
// Scope 1: A and C active. Also asserts the primal call record.
{
std::string Test = "SYRK active C, A ";
init();
my_dsyrk(layout, uplo, (char)transA, N, K, alpha, A, lda, beta, C,
incC);
assert(calls.size() == 1);
assert(calls[0].inDerivative == false);
assert(calls[0].type == CallType::SYRK);
assert(calls[0].pout_arg1 == C);
assert(calls[0].pin_arg1 == A);
assert(calls[0].pin_arg2 == UNUSED_POINTER);
assert(calls[0].farg1 == alpha);
assert(calls[0].farg2 == beta);
assert(calls[0].layout == layout);
assert(calls[0].targ1 == (char)transA);
assert(calls[0].targ2 == UNUSED_TRANS);
assert(calls[0].iarg1 == N);
assert(calls[0].iarg2 == K);
assert(calls[0].iarg3 == UNUSED_INT);
assert(calls[0].iarg4 == lda);
assert(calls[0].iarg5 == incC);
assert(calls[0].iarg6 == UNUSED_INT);
assert(calls[0].side == UNUSED_TRANS);
assert(calls[0].uplo == uplo);
assert(calls[0].diag == UNUSED_TRANS);
// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);
init();
__enzyme_fwddiff<void>(
(void *)my_dsyrk, enzyme_const, layout, enzyme_const, uplo,
enzyme_const, transA, enzyme_const, N, enzyme_const, K,
enzyme_const, alpha, enzyme_dup, A, dA, enzyme_const, lda,
enzyme_const, beta, enzyme_dup, C, dC, enzyme_const, incC);
foundCalls = calls;
init();
// Expected calls: syr2k(A, dA) into dC scaled by beta, then the primal.
cblas_dsyr2k(layout, uplo, (char)transA, N, K, alpha, A, lda, dA, lda, beta, dC, incC);
my_dsyrk(layout, uplo, (char)transA, N, K, alpha, A, lda, beta, C,
incC);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}
// Scope 2: only C active; A is passed as enzyme_const.
{
std::string Test = "SYRK active C ";
init();
__enzyme_fwddiff<void>(
(void *)my_dsyrk, enzyme_const, layout, enzyme_const, uplo,
enzyme_const, transA, enzyme_const, N, enzyme_const, K,
enzyme_const, alpha, enzyme_const, A, enzyme_const, lda,
enzyme_const, beta, enzyme_dup, C, dC, enzyme_const, incC);
foundCalls = calls;
init();
// Expected calls: lascl on dC with the factor pair (1.0, beta), then the
// primal.
cblas_dlascl(layout, uplo, 0, 0, 1.0, beta, N, N, dC, incC, 0);
my_dsyrk(layout, uplo, (char)transA, N, K, alpha, A, lda, beta, C,
incC);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}
// Scope 3: alpha active (tangent dalpha) together with C. Despite the test
// name, A is also passed as enzyme_dup here, which is why a syr2k(A, dA)
// call is expected as well.
{
std::string Test = "SYRK active C, alpha ";
init();
double dalpha = 46.1345;
__enzyme_fwddiff<void>(
(void *)my_dsyrk, enzyme_const, layout, enzyme_const, uplo,
enzyme_const, transA, enzyme_const, N, enzyme_const, K,
enzyme_dup, alpha, dalpha, enzyme_dup, A, dA, enzyme_const, lda,
enzyme_const, beta, enzyme_dup, C, dC, enzyme_const, incC);
foundCalls = calls;
init();
// Expected calls: syr2k(A, dA) into dC scaled by beta, then syrk(A) with
// dalpha accumulated into dC with a factor of 1.0, then the primal.
cblas_dsyr2k(layout, uplo, (char)transA, N, K, alpha, A, lda, dA, lda, beta, dC, incC);
cblas_dsyrk(layout, uplo, (char)transA, N, K, dalpha, A, lda, 1.0, dC, incC);
my_dsyrk(layout, uplo, (char)transA, N, K, alpha, A, lda, beta, C,
incC);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}
}
}
}
static void potrfTests() {
int N = 17;
// N means normal matrix, T means transposed
for (char layout : {CblasColMajor, CblasRowMajor}) {
for (auto uplo : {'U', 'u', 'L', 'l'})
{
BlasInfo inputs[6] = {
/*A*/ BlasInfo(A, layout, N, N, lda),
/*B*/ BlasInfo(),
/*C*/ BlasInfo(),
BlasInfo(),
BlasInfo(),
BlasInfo(),
};
{
std::string Test = "POTRF active A ";
init();
my_potrf(layout, uplo, N, A, lda);
assert(calls.size() == 1);
assert(calls[0].inDerivative == false);
assert(calls[0].type == CallType::POTRF);
assert(calls[0].pout_arg1 == A);
assert(calls[0].pin_arg1 == UNUSED_POINTER);
assert(calls[0].pin_arg2 == UNUSED_POINTER);
assert(calls[0].farg1 == UNUSED_DOUBLE);
assert(calls[0].farg2 == UNUSED_DOUBLE);
assert(calls[0].layout == layout);
assert(calls[0].targ1 == UNUSED_TRANS);
assert(calls[0].targ2 == UNUSED_TRANS);
assert(calls[0].iarg1 == N);
assert(calls[0].iarg2 == UNUSED_INT);
assert(calls[0].iarg3 == UNUSED_INT);
assert(calls[0].iarg4 == lda);
assert(calls[0].iarg5 == UNUSED_INT);
assert(calls[0].iarg6 == UNUSED_INT);
assert(calls[0].side == UNUSED_TRANS);
assert(calls[0].uplo == uplo);
assert(calls[0].diag == UNUSED_TRANS);
// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);
init();
__enzyme_fwddiff<void>(
(void *)my_potrf, enzyme_const, layout, enzyme_const, uplo,
enzyme_const, N, enzyme_dup, A, dA, enzyme_const, lda);
foundCalls = calls;
init();
my_potrf(layout, uplo, N, A, lda);
assert(foundCalls.size() >= 2);
assert(foundCalls[1].type == CallType::LACPY);
double* tri = (double*)foundCalls[1].pout_arg1;
inputs[3] = BlasInfo(tri, layout, N, N, N);
cblas_dlacpy(layout, flip_uplo(uplo), N, N, dA, lda, tri, N);
#define dAv(r, c) \
dA[(r) * (layout == CblasRowMajor ? lda : 1) + \
(c) * (layout == CblasRowMajor ? 1 : lda)]
int upperinc = (&dAv(0, 1) - &dAv(0, 0));
int lowerinc = (&dAv(1, 0) - &dAv(0, 0));
if (layout == CblasColMajor) {
assert(upperinc == lda);
assert(lowerinc == 1);
} else {
assert(upperinc == 1);
assert(lowerinc == lda);
}
bool is_lower = uplo == 'L' || uplo == 'l';
for (int i = 0; i < N - 1; i++) {
cblas_dcopy(N - i - 1,
is_lower ? (&dAv(i + 1, i)) : (&dAv(i, i + 1)),
is_lower ? lowerinc : upperinc,
is_lower ? (&dAv(i, i + 1)) : (&dAv(i + 1, i)),
is_lower ? upperinc : lowerinc);
}
cblas_dtrsm(layout, 'L', uplo, uplo_to_normal(uplo), 'N', N, N, 1.0, A, lda, dA, lda);
cblas_dtrsm(layout, 'R', uplo, uplo_to_trans(uplo), 'N', N, N, 1.0, A, lda, dA, lda);
cblas_dscal(N, 0.5, dA, lda+1);
assert(foundCalls.size() >= 9);
assert(foundCalls[21].type == CallType::COPY);
double *tmp = (double *)foundCalls[21].pout_arg1;
inputs[4] = BlasInfo(tmp, N, 1);
cblas_dcopy(N, dA, lda+1, tmp, 1);
cblas_dlascl(layout, flip_uplo(uplo), 0, 0, 1.0, 0.0, N, N, dA, lda, 0);
cblas_dcopy(N, tmp, 1, dA, lda+1);
cblas_dtrmm(layout, uplo_to_side(uplo), uplo, 'N', 'N', N, N, 1.0, A, lda, dA, lda);
cblas_dcopy(N, dA, lda+1, tmp, 1);
cblas_dlacpy(layout, flip_uplo(uplo), N, N, tri, N, dA, lda);
cblas_dcopy(N, tmp, 1, dA, lda+1);
checkTest(Test);
SkipVecIncCheck = true;
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
SkipVecIncCheck = false;
}
}
}
}
// Entry point: runs every test group once, in a fixed order.
int main() {
  using TestGroup = void (*)();
  // Order matches the original call sequence; symvTests runs after potrfTests.
  const TestGroup groups[] = {dotTests,  nrm2Tests,  gemvTests, gemmTests,
                              syrkTests, potrfTests, symvTests};
  for (TestGroup runGroup : groups)
    runGroup();
  return 0;
}