// blob: 8c929ed054d0816a25b5d4283f17220207cfb3e8 [file] [log] [blame] [edit]
// This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load
// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions...
// RUN: %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli -
// RUN: %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli -
// RUN: %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli -
// RUN: %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli -
// RUN: %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli -
// RUN: %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli -
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
#include "../blasinfra.h"
// Marker globals passed to __enzyme_autodiff immediately before each argument
// of the differentiated function. In the calls in this file, enzyme_const
// precedes a single value and enzyme_dup precedes a (value, shadow) pair such
// as `A, dA`. enzyme_out is declared but not used in the code visible here.
// NOTE(review): the exact activity semantics are defined by Enzyme, not by this
// file -- confirm against the Enzyme calling-convention documentation.
int enzyme_dup;
int enzyme_out;
int enzyme_const;
// Declaration only: no definition is provided in this file. Presumably the
// call is resolved by the Enzyme plugin loaded in the RUN lines above. The
// first argument is the function to differentiate, cast to void*; the
// remaining arguments are the marker/value sequence described above.
template<typename ...T>
void __enzyme_autodiff(void*, T...);
// GEMV wrapper used as a differentiation target. The matrix and both vectors
// are declared __restrict__. Forwards every argument unchanged to cblas_dgemv
// and then raises the global inDerivative flag, so the flag is still in its
// previous state while the forwarded call itself executes.
void my_dgemv(char layout, char trans, int M, int N,
              double alpha, double* __restrict__ A, int lda,
              double* __restrict__ X, int incx,
              double beta, double* __restrict__ Y, int incy) {
  cblas_dgemv(layout, trans, M, N,
              alpha, A, lda,
              X, incx,
              beta, Y, incy);
  inDerivative = true;
}
// Same forwarding wrapper as the __restrict__ GEMV variant, but with plain
// (possibly aliasing) pointers for A, X and Y. It is the target of the
// "GEMV active/overwrite" scenario. Forwards to cblas_dgemv and then raises
// the global inDerivative flag.
void ow_dgemv(char layout, char trans, int M, int N,
              double alpha, double* A, int lda,
              double* X, int incx,
              double beta, double* Y, int incy) {
  cblas_dgemv(layout, trans, M, N,
              alpha, A, lda,
              X, incx,
              beta, Y, incy);
  inDerivative = true;
}
// DOT wrapper used as a differentiation target. Calls cblas_ddot on the two
// __restrict__ vectors, raises the global inDerivative flag once that call has
// completed, and returns the value cblas_ddot produced.
double my_ddot(int N,
               double* __restrict__ X, int incx,
               double* __restrict__ Y, int incy) {
  double dot_value = cblas_ddot(N, X, incx, Y, incy);
  inDerivative = true;
  return dot_value;
}
// GEMM wrapper used both as a differentiation target and, in gemmTests, to
// replay the expected derivative products. Forwards every argument unchanged
// to cblas_dgemm and then raises the global inDerivative flag.
void my_dgemm(char layout, char transA, char transB,
              int M, int N, int K,
              double alpha, double* __restrict__ A, int lda,
              double* __restrict__ B, int ldb,
              double beta, double* __restrict__ C, int ldc) {
  cblas_dgemm(layout, transA, transB,
              M, N, K,
              alpha, A, lda,
              B, ldb,
              beta, C, ldc);
  inDerivative = true;
}
static void dotTests() {
std::string Test = "DOT active both ";
BlasInfo inputs[6] = {
/*A*/ BlasInfo(A, N, incA),
/*B*/ BlasInfo(B, N, incB),
/*C*/ BlasInfo(C, M, incC),
BlasInfo(),
BlasInfo(),
BlasInfo(),
};
init();
my_ddot(N, A, incA, B, incB);
// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);
init();
__enzyme_autodiff((void*) my_ddot,
enzyme_const, N,
enzyme_dup, A, dA,
enzyme_const, incA,
enzyme_dup, B, dB,
enzyme_const, incB);
foundCalls = calls;
init();
my_ddot(N, A, incA, B, incB);
inDerivative = true;
cblas_daxpy(N, 1.0, B, incB, dA, incA);
cblas_daxpy(N, 1.0, A, incA, dB, incB);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}
// Reverse-mode checks for GEMV, repeated for every combination of layout
// (row/column major) and transpose flag (NoTrans/Trans).
//
// Each iteration:
//   * runs the primal my_dgemv once and asserts that exactly one GEMV call was
//     recorded with the expected pointers, scalars, layout, transpose flag and
//     integer arguments;
//   * runs three differentiation scenarios ("active A, C", "active A, B, C",
//     "active/overwrite"). Every scenario calls __enzyme_autodiff, snapshots
//     the recorded calls into foundCalls, replays the expected call sequence
//     by hand, and compares both with checkTest / checkMemoryTrace.
//
// NOTE(review): init(), checkTest(), checkMemoryTrace() and the cblas_* entry
// points come from blasinfra.h and are not visible here; presumably init()
// resets the recorded trace and the buffers -- confirm against that header.
static void gemvTests() {
// N means normal matrix, T means transposed
for (char layout : { CblasRowMajor, CblasColMajor }) {
for (auto transA : {CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans}) {
// todo in fortran blas consider 'N', 'n', 'T', 't'}
{
bool trans = !is_normal(transA);
std::string Test = "GEMV active A, C ";
// A is M x N; the lengths of the B (input) and C (output) vectors swap when
// the matrix is applied transposed. Slots 3..5 are unused until the overwrite
// scenario registers its caches in slots 4 and 5.
BlasInfo inputs[6] = {
/*A*/ BlasInfo(A, layout, M, N, lda),
/*B*/ BlasInfo(B, trans ? M : N, incB),
/*C*/ BlasInfo(C, trans ? N : M, incC),
BlasInfo(),
BlasInfo(),
BlasInfo()
};
init();
my_dgemv(layout, (char)transA, M, N, alpha, A, lda, B, incB, beta, C, incC);
// The primal must have recorded exactly one GEMV call, outside the derivative,
// with every argument forwarded unchanged.
assert(calls.size() == 1);
assert(calls[0].inDerivative == false);
assert(calls[0].type == CallType::GEMV);
assert(calls[0].pout_arg1 == C);
assert(calls[0].pin_arg1 == A);
assert(calls[0].pin_arg2 == B);
assert(calls[0].farg1 == alpha);
assert(calls[0].farg2 == beta);
assert(calls[0].layout == layout);
assert(calls[0].targ1 == (char)transA);
assert(calls[0].targ2 == UNUSED_TRANS);
assert(calls[0].iarg1 == M);
assert(calls[0].iarg2 == N);
assert(calls[0].iarg3 == UNUSED_INT);
assert(calls[0].iarg4 == lda);
assert(calls[0].iarg5 == incB);
assert(calls[0].iarg6 == incC);
// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);
// Scenario 1: A and C carry shadows (dA, dC); B is passed as constant.
init();
__enzyme_autodiff((void*) my_dgemv,
enzyme_const, layout,
enzyme_const, transA,
enzyme_const, M,
enzyme_const, N,
enzyme_const, alpha,
enzyme_dup, A, dA,
enzyme_const, lda,
enzyme_const, B,
enzyme_const, incB,
enzyme_const, beta,
enzyme_dup, C, dC,
enzyme_const, incC);
foundCalls = calls;
// Replay the expected sequence: primal first, then the derivative calls.
init();
my_dgemv(layout, (char)transA, M, N, alpha, A, lda, B, incB, beta, C, incC);
inDerivative = true;
// dA = alpha * x * transpose(y) + dA, with (x, y) = (dC, B), or (B, dC) when
// the matrix is transposed
cblas_dger(layout, M, N, alpha, trans ? B : dC, trans ? incB : incC, trans ? dC : B, trans ? incC : incB, dA, lda);
// dC = beta * dC
cblas_dscal(trans ? N : M, beta, dC, incC);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
// Scenario 2: A, B and C all carry shadows.
Test = "GEMV active A, B, C ";
init();
__enzyme_autodiff((void*) my_dgemv,
enzyme_const, layout,
enzyme_const, transA,
enzyme_const, M,
enzyme_const, N,
enzyme_const, alpha,
enzyme_dup, A, dA,
enzyme_const, lda,
enzyme_dup, B, dB,
enzyme_const, incB,
enzyme_const, beta,
enzyme_dup, C, dC,
enzyme_const, incC);
foundCalls = calls;
init();
my_dgemv(layout, (char)transA, M, N, alpha, A, lda, B, incB, beta, C, incC);
inDerivative = true;
// dA = alpha * x * transpose(y) + dA, with (x, y) = (dC, B), or (B, dC) when
// the matrix is transposed
cblas_dger(layout, M, N, alpha, trans ? B : dC, trans ? incB : incC, trans ? dC : B, trans ? incC : incB, dA, lda);
// dB = alpha * trans(A) * dC + dB
cblas_dgemv(layout, (char)transpose(transA), M, N, alpha, A, lda, dC, incC, 1.0, dB, incB);
// dC = beta * dC
cblas_dscal(trans ? N : M, beta, dC, incC);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
// Scenario 3: same activity as scenario 2, but through ow_dgemv, whose
// pointers are not __restrict__. The expected sequence therefore starts with
// copies of A and B into cache buffers, and the derivative reads from those
// caches instead of from A and B.
Test = "GEMV active/overwrite";
init();
__enzyme_autodiff((void*) ow_dgemv,
enzyme_const, layout,
enzyme_const, transA,
enzyme_const, M,
enzyme_const, N,
enzyme_const, alpha,
enzyme_dup, A, dA,
enzyme_const, lda,
enzyme_dup, B, dB,
enzyme_const, incB,
enzyme_const, beta,
enzyme_dup, C, dC,
enzyme_const, incC);
foundCalls = calls;
init();
// The cache buffers are not known up front: take their addresses from the
// output pointers of the first two calls recorded during differentiation.
assert(foundCalls.size() > 2);
auto A_cache = (double*)foundCalls[0].pout_arg1;
// Replay the matrix copy; the cache is described with leading dimension N.
cblas_dlacpy(layout, '\0', M, N, A, lda, A_cache, N);
inputs[4] = BlasInfo(A_cache, layout, M, N, N);
auto B_cache = (double*)foundCalls[1].pout_arg1;
// Replay the vector copy; the cache is written with unit stride.
cblas_dcopy(trans ? M : N, B, incB, B_cache, 1);
inputs[5] = BlasInfo(B_cache, trans ? M : N, 1);
ow_dgemv(layout, (char)transA, M, N, alpha, A, lda, B, incB, beta, C, incC);
inDerivative = true;
// dA = alpha * x * transpose(y) + dA, with (x, y) = (dC, B_cache), or
// (B_cache, dC) when the matrix is transposed
cblas_dger(layout, M, N, alpha,
trans ? B_cache : dC,
trans ? 1 : incC,
trans ? dC : B_cache,
trans ? incC : 1, dA,
lda);
// dB = alpha * trans(A_cache) * dC + dB
cblas_dgemv(layout, (char)transpose(transA), M, N, alpha, A_cache, N, dC, incC, 1.0, dB, incB);
// dC = beta * dC
cblas_dscal(trans ? N : M, beta, dC, incC);
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
// Drop the cache descriptors again.
inputs[4] = BlasInfo();
inputs[5] = BlasInfo();
}
}
}
}
// Reverse-mode checks for GEMM, repeated for every combination of layout
// (row/column major) and the transpose flags of A and B.
//
// Each iteration runs the primal my_dgemm once and asserts that exactly one
// GEMM call was recorded with the expected arguments. It then differentiates
// my_dgemm with A, B and C all carrying shadows, snapshots the recorded calls
// into foundCalls, replays the expected sequence by hand (primal, two GEMM
// products accumulating into dA and dB, then a scaling of dC), and compares
// both with checkTest / checkMemoryTrace.
//
// Note that the leading dimensions of B and C are passed through the globals
// named incB and incC here.
//
// NOTE(review): init(), checkTest(), checkMemoryTrace() and the cblas_* entry
// points come from blasinfra.h and are not visible here -- confirm their exact
// behaviour against that header.
static void gemmTests() {
// N means normal matrix, T means transposed
for (char layout : { CblasRowMajor, CblasColMajor }) {
for (auto transA : {CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans}) {
for (auto transB : {CBLAS_TRANSPOSE::CblasNoTrans, CBLAS_TRANSPOSE::CblasTrans}) {
// todo fortran blas {'N', 'n', 'T', 't'}
{
bool transA_bool = !is_normal(transA);
bool transB_bool = !is_normal(transB);
std::string Test = "GEMM";
// Stored shapes: A is M x K (K x M when transposed), B is K x N (N x K when
// transposed), C is M x N. The last three slots are unused.
BlasInfo inputs[6] = {
/*A*/ BlasInfo(A, layout, transA_bool ? K : M, transA_bool ? M : K, lda),
/*B*/ BlasInfo(B, layout, transB_bool ? N : K , transB_bool ? K : N, incB),
/*C*/ BlasInfo(C, layout, M, N, incC),
BlasInfo(),
BlasInfo(),
BlasInfo()
};
init();
my_dgemm(layout, (char)transA, (char)transB, M, N, K, alpha, A, lda, B, incB, beta, C, incC);
// The primal must have recorded exactly one GEMM call, outside the derivative,
// with every argument forwarded unchanged.
assert(calls.size() == 1);
assert(calls[0].inDerivative == false);
assert(calls[0].type == CallType::GEMM);
assert(calls[0].pout_arg1 == C);
assert(calls[0].pin_arg1 == A);
assert(calls[0].pin_arg2 == B);
assert(calls[0].farg1 == alpha);
assert(calls[0].farg2 == beta);
assert(calls[0].layout == layout);
assert(calls[0].targ1 == (char)transA);
assert(calls[0].targ2 == (char)transB);
assert(calls[0].iarg1 == M);
assert(calls[0].iarg2 == N);
assert(calls[0].iarg3 == K);
assert(calls[0].iarg4 == lda);
assert(calls[0].iarg5 == incB);
assert(calls[0].iarg6 == incC);
// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);
// Differentiate with A, B and C all carrying shadows (dA, dB, dC).
init();
__enzyme_autodiff((void*) my_dgemm,
enzyme_const, layout,
enzyme_const, transA,
enzyme_const, transB,
enzyme_const, M,
enzyme_const, N,
enzyme_const, K,
enzyme_const, alpha,
enzyme_dup, A, dA,
enzyme_const, lda,
enzyme_dup, B, dB,
enzyme_const, incB,
enzyme_const, beta,
enzyme_dup, C, dC,
enzyme_const, incC);
foundCalls = calls;
// Replay the expected sequence: primal first, then the derivative calls.
init();
my_dgemm(layout, (char)transA, (char)transB, M, N, K, alpha, A, lda, B, incB, beta, C, incC);
inDerivative = true;
// dA = alpha * dC * transpose(op(B)) + dA            (A not transposed)
// dA = alpha * op(B) * transpose(dC) + dA            (A transposed)
my_dgemm(layout,
transA_bool ? (char)transB : (char)CBLAS_TRANSPOSE::CblasNoTrans,
transA_bool ? (char)CBLAS_TRANSPOSE::CblasTrans : (char)transpose(transB),
transA_bool ? K : M,
transA_bool ? M : K,
N,
alpha,
transA_bool ? B : dC,
transA_bool ? incB : incC,
transA_bool ? dC : B,
transA_bool ? incC : incB,
1.0, dA, lda);
// dB = alpha * transpose(op(A)) * dC + dB            (B not transposed)
// dB = alpha * transpose(dC) * op(A) + dB            (B transposed)
my_dgemm(layout,
transB_bool ? (char)CBLAS_TRANSPOSE::CblasTrans : (char)transpose(transA),
transB_bool ? (char)transA : (char)CBLAS_TRANSPOSE::CblasNoTrans, //transB,
transB_bool ? N : K,
transB_bool ? K : N,
M,
alpha,
transB_bool ? dC : A,
transB_bool ? incC : lda,
transB_bool ? A : dC,
transB_bool ? lda : incC,
1.0, dB, incB);
// Presumably dC = beta * dC: general-matrix ('G') scaling of the M x N shadow
// with divisor 1.0 and multiplier beta -- confirm against the dlascl mock.
cblas_dlascl(layout, 'G', 0, 0, 1.0, beta, M, N, dC, incC, 0 );
checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);
// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}
}
}
}
}
// Test driver: runs the DOT, GEMV and GEMM suites in that order. A failure
// inside a suite is expected to abort via assert, so reaching the end of main
// means every check passed.
int main() {
  dotTests();
  gemvTests();
  gemmTests();
  return 0;
}