//===- MemRefAutoDiffOpInterfaceImpl.cpp - Interface external model -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the external model implementation of the automatic
// differentiation op interfaces for the upstream MLIR memref dialect.
//
//===----------------------------------------------------------------------===//
#include "Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Interfaces/AutoDiffOpInterface.h"
#include "Interfaces/AutoDiffTypeInterface.h"
#include "Interfaces/GradientUtils.h"
#include "Interfaces/GradientUtilsReverse.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"
using namespace mlir;
using namespace mlir::enzyme;
namespace {
#include "Implementations/MemRefDerivatives.inc"
/// External model providing reverse-mode differentiation for `memref.load`.
///
/// During the forward sweep the shadow of the loaded-from memref and the load
/// indices are cached; during the reverse sweep the gradient of the loaded
/// value is accumulated into that shadow at the same indices.
struct LoadOpInterfaceReverse
    : public ReverseAutoDiffOpInterface::ExternalModel<LoadOpInterfaceReverse,
                                                       memref::LoadOp> {
  /// Emits the adjoint of the load at `builder`'s insertion point.
  ///
  /// `caches` must be the list produced by `cacheValues`: the first entry is
  /// the cached shadow memref, the remaining entries are the cached indices.
  /// Nothing is emitted when the element type has no AutoDiffTypeInterface or
  /// when either the loaded value or the memref is constant. Always returns
  /// success.
  LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
                                         MGradientUtilsReverse *gutils,
                                         SmallVector<Value> caches) const {
    auto load = cast<memref::LoadOp>(op);
    Value source = load.getMemref();

    auto eltIface = dyn_cast<AutoDiffTypeInterface>(load.getType());
    if (!eltIface)
      return success();
    if (gutils->isConstantValue(load) || gutils->isConstantValue(source))
      return success();

    Location loc = load.getLoc();
    Value resultGrad = gutils->diffe(load, builder);

    // Pop in the same order the values were pushed: shadow first, then one
    // entry per index.
    Value shadowMemref = gutils->popCache(caches.front(), builder);
    SmallVector<Value> indexValues;
    for (Value cachedIndex : ValueRange(caches).drop_front(1))
      indexValues.push_back(gutils->popCache(cachedIndex, builder));

    if (gutils->AtomicAdd) {
      // Single atomic read-modify-write on the shadow element.
      memref::AtomicRMWOp::create(builder, loc, arith::AtomicRMWKind::addf,
                                  resultGrad, shadowMemref,
                                  ArrayRef<Value>(indexValues));
      return success();
    }

    // Non-atomic accumulation: load, add, store back.
    Value current = memref::LoadOp::create(builder, loc, shadowMemref,
                                           ArrayRef<Value>(indexValues));
    Value accumulated = eltIface.createAddOp(builder, loc, current, resultGrad);
    memref::StoreOp::create(builder, loc, accumulated, shadowMemref,
                            ArrayRef<Value>(indexValues));
    return success();
  }

  /// Caches what the reverse pass needs: the shadow of the memref followed by
  /// every load index (taken from the cloned function). The cache operations
  /// are inserted before the clone of `op`. Returns an empty list under the
  /// same conditions in which `createReverseModeAdjoint` emits nothing.
  SmallVector<Value> cacheValues(Operation *op,
                                 MGradientUtilsReverse *gutils) const {
    auto load = cast<memref::LoadOp>(op);
    Value source = load.getMemref();

    if (!isa<AutoDiffTypeInterface>(load.getType()))
      return {};
    if (gutils->isConstantValue(load) || gutils->isConstantValue(source))
      return {};

    OpBuilder cacheBuilder(gutils->getNewFromOriginal(op));
    SmallVector<Value> cached;

    Value shadowMemref = gutils->invertPointerM(source, cacheBuilder);
    cached.push_back(gutils->initAndPushCache(shadowMemref, cacheBuilder));

    for (Value index : load.getIndices()) {
      Value clonedIndex = gutils->getNewFromOriginal(index);
      cached.push_back(gutils->initAndPushCache(clonedIndex, cacheBuilder));
    }
    return cached;
  }

  /// Intentionally a no-op: a load currently never produces a shadow value.
  /// In the future this could support memref<memref<...>>.
  void createShadowValues(Operation *op, OpBuilder &builder,
                          MGradientUtilsReverse *gutils) const {}
};
/// External model providing reverse-mode differentiation for `memref.store`.
///
/// During the forward sweep the shadow of the stored-to memref and the store
/// indices are cached; during the reverse sweep the gradient held in the
/// shadow element is propagated to the stored value and the element is reset
/// to zero.
struct StoreOpInterfaceReverse
    : public ReverseAutoDiffOpInterface::ExternalModel<StoreOpInterfaceReverse,
                                                       memref::StoreOp> {
  /// Emits the adjoint of the store at `builder`'s insertion point.
  ///
  /// `caches` must be the list produced by `cacheValues`: the first entry is
  /// the cached shadow memref, the remaining entries are the cached indices.
  ///
  /// Returns success when the destination memref is constant (nothing to do)
  /// or when the adjoint was emitted. Returns failure with a diagnostic when
  /// the memref is active but the stored value's type does not implement
  /// AutoDiffTypeInterface; in that case `cacheValues` cached nothing, so
  /// there is no shadow to operate on.
  LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
                                         MGradientUtilsReverse *gutils,
                                         SmallVector<Value> caches) const {
    auto storeOp = cast<memref::StoreOp>(op);
    Value val = storeOp.getValue();
    Value memref = storeOp.getMemref();

    // Check activity before touching the type interface so that stores into
    // constant memrefs never require the element type to be differentiable.
    if (gutils->isConstantValue(memref))
      return success();

    // Mirror the `dyn_cast` in `cacheValues` instead of asserting via `cast`.
    auto iface = dyn_cast<AutoDiffTypeInterface>(val.getType());
    if (!iface)
      return op->emitError("cannot differentiate store into an active memref: "
                           "stored value type does not implement "
                           "AutoDiffTypeInterface");

    // Pop in the same order the values were pushed: shadow first, then one
    // entry per index.
    Value memrefGradient = gutils->popCache(caches.front(), builder);
    SmallVector<Value> retrievedArguments;
    for (Value cache : ValueRange(caches).drop_front(1))
      retrievedArguments.push_back(gutils->popCache(cache, builder));

    // Mutable element types are not handled here; the caches have still been
    // popped above so that pushes and pops stay balanced.
    if (iface.isMutable())
      return success();

    if (!gutils->isConstantValue(val)) {
      // The gradient stored in the shadow element flows to the stored value.
      Value loadedGradient =
          memref::LoadOp::create(builder, storeOp.getLoc(), memrefGradient,
                                 ArrayRef<Value>(retrievedArguments));
      gutils->addToDiffe(val, loadedGradient, builder);
    }

    // Reset the shadow element to the null value of the shadow type.
    auto zero =
        cast<AutoDiffTypeInterface>(gutils->getShadowType(val.getType()))
            .createNullValue(builder, op->getLoc());
    memref::StoreOp::create(builder, storeOp.getLoc(), zero, memrefGradient,
                            ArrayRef<Value>(retrievedArguments));
    return success();
  }

  /// Caches what the reverse pass needs: the shadow of the memref followed by
  /// every store index (taken from the cloned function). The cache operations
  /// are inserted before the clone of `op`. Returns an empty list when the
  /// stored value's type has no AutoDiffTypeInterface or the memref is
  /// constant.
  SmallVector<Value> cacheValues(Operation *op,
                                 MGradientUtilsReverse *gutils) const {
    auto storeOp = cast<memref::StoreOp>(op);
    Value memref = storeOp.getMemref();
    ValueRange indices = storeOp.getIndices();
    Value val = storeOp.getValue();

    if (!isa<AutoDiffTypeInterface>(val.getType()) ||
        gutils->isConstantValue(memref))
      return SmallVector<Value>();

    OpBuilder cacheBuilder(gutils->getNewFromOriginal(op));
    SmallVector<Value> caches;
    caches.push_back(gutils->initAndPushCache(
        gutils->invertPointerM(memref, cacheBuilder), cacheBuilder));
    for (Value v : indices) {
      caches.push_back(gutils->initAndPushCache(gutils->getNewFromOriginal(v),
                                                cacheBuilder));
    }
    return caches;
  }

  /// Intentionally a no-op: a store has no results and so no shadow values.
  /// In the future this could support memref<memref<...>>.
  void createShadowValues(Operation *op, OpBuilder &builder,
                          MGradientUtilsReverse *gutils) const {}
};
/// External model providing reverse-mode differentiation for `memref.subview`.
///
/// A subview contributes no adjoint code and caches nothing; its only job is
/// to create a shadow subview over the shadow of its source.
struct SubViewOpInterfaceReverse
    : public ReverseAutoDiffOpInterface::ExternalModel<
          SubViewOpInterfaceReverse, memref::SubViewOp> {
  /// No adjoint code is emitted for a subview. Always returns success.
  LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
                                         MGradientUtilsReverse *gutils,
                                         SmallVector<Value> caches) const {
    return success();
  }

  /// Nothing needs to be cached for the reverse pass.
  SmallVector<Value> cacheValues(Operation *op,
                                 MGradientUtilsReverse *gutils) const {
    return {};
  }

  /// When the source memref is not constant, builds a `memref.subview` of the
  /// source's shadow using the result type, offsets, sizes and strides of the
  /// cloned subview, and registers it as the inverted pointer of the original
  /// subview result.
  void createShadowValues(Operation *op, OpBuilder &builder,
                          MGradientUtilsReverse *gutils) const {
    auto origView = cast<memref::SubViewOp>(op);
    auto clonedView = cast<memref::SubViewOp>(gutils->getNewFromOriginal(op));

    Value origSource = origView.getSource();
    if (gutils->isConstantValue(origSource))
      return;

    Value shadowSource = gutils->invertPointerM(origSource, builder);
    Value shadowView = memref::SubViewOp::create(
        builder, op->getLoc(), clonedView.getType(), shadowSource,
        clonedView.getMixedOffsets(), clonedView.getMixedSizes(),
        clonedView.getMixedStrides());
    gutils->setInvertedPointer(origView, shadowView);
  }
};
/// External model providing reverse-mode differentiation for
/// `memref.alloca_scope`.
///
/// The reverse code for the scope body is emitted inside the cloned scope op,
/// which is moved to the reverse builder's insertion point. Gradients of
/// immutable values captured from outside the scope are returned from the
/// scope as extra results and accumulated after it.
struct AllocaScopeOpInterfaceReverse
: public ReverseAutoDiffOpInterface::ExternalModel<
AllocaScopeOpInterfaceReverse, memref::AllocaScopeOp> {
/// Emits the adjoint of the scope. Returns failure if visiting any operation
/// of the body fails, success otherwise. `caches` is unused.
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto scopeOp = cast<memref::AllocaScopeOp>(op);
// Step 1: for every active result, read its incoming gradient and reset it.
// `incomingGradients` holds one entry per *active* result, in result order.
SmallVector<bool> resultsActive(scopeOp->getNumResults(), false);
SmallVector<Value> incomingGradients;
incomingGradients.reserve(scopeOp->getNumResults());
for (OpResult result : scopeOp->getResults()) {
bool active = !gutils->isConstantValue(result);
resultsActive[result.getResultNumber()] = active;
if (active) {
incomingGradients.push_back(gutils->diffe(result, builder));
gutils->zeroDiffe(result, builder);
}
}
// Step 2: collect (without duplicates) the operands used inside the scope
// that are defined outside of it, are not constant, and whose type
// implements AutoDiffTypeInterface and is not mutable.
Region &scopeRegion = scopeOp.getBodyRegion();
SmallVector<Value> capturedInputs;
scopeRegion.walk([&](Operation *inner) {
for (Value operand : inner->getOperands()) {
Region *defRegion = operand.getParentRegion();
// Skip values without a parent region and values defined inside the scope.
if (!defRegion || scopeRegion.isAncestor(defRegion))
continue;
if (gutils->isConstantValue(operand))
continue;
auto iface = dyn_cast<AutoDiffTypeInterface>(operand.getType());
if (!iface || iface.isMutable())
continue;
if (!llvm::is_contained(capturedInputs, operand))
capturedInputs.push_back(operand);
}
});
// Step 3: save the gradient each captured input had before the scope and
// reset it, so that the gradient read inside the body is only the body's
// own contribution.
SmallVector<Value> capturedPre;
capturedPre.reserve(capturedInputs.size());
for (Value v : capturedInputs) {
capturedPre.push_back(gutils->diffe(v, builder));
gutils->zeroDiffe(v, builder);
}
// Step 4: move the cloned scope op to the reverse insertion point; reverse
// code for the body is inserted right before its terminator.
auto newScope = cast<memref::AllocaScopeOp>(gutils->getNewFromOriginal(op));
newScope->moveBefore(builder.getInsertionBlock(),
builder.getInsertionPoint());
Block &newBody = newScope.getBodyRegion().front();
OpBuilder bodyBuilder(newBody.getTerminator());
Block &oldBody = scopeOp.getBodyRegion().front();
bool valid = true;
// Reset the gradient of every non-constant, immutable result produced by an
// operation directly in the original body block.
for (Operation &innerOp : oldBody.getOperations()) {
for (Value result : innerOp.getResults()) {
if (!gutils->isConstantValue(result)) {
auto iface = dyn_cast<AutoDiffTypeInterface>(result.getType());
if (iface && !iface.isMutable())
gutils->zeroDiffe(result, bodyBuilder);
}
}
}
// Step 5: seed the gradients of the original terminator's operands with the
// incoming gradients of the matching scope results.
Operation *term = oldBody.getTerminator();
unsigned incomingIdx = 0;
for (auto indexedOperand : llvm::enumerate(term->getOperands())) {
if (!resultsActive[indexedOperand.index()])
continue;
Value operand = indexedOperand.value();
if (!gutils->isConstantValue(operand))
gutils->addToDiffe(operand, incomingGradients[incomingIdx],
bodyBuilder);
// Advance for every active result, even if the yielded operand is constant,
// to stay aligned with `incomingGradients`.
incomingIdx++;
}
// Step 6: visit the original body's operations in reverse order, skipping
// the last operation of the block (the terminator).
auto first = oldBody.rbegin();
first++;
for (auto it = first; it != oldBody.rend(); it++) {
valid &= gutils->Logic.visitChild(&*it, bodyBuilder, gutils).succeeded();
}
// Without captured inputs no extra results are needed; done.
if (capturedInputs.empty())
return success(valid);
// Step 7: read, inside the body, the gradient accumulated for each captured
// input; these values become additional results of the scope.
SmallVector<Value> contributions;
contributions.reserve(capturedInputs.size());
for (Value v : capturedInputs)
contributions.push_back(gutils->diffe(v, bodyBuilder));
unsigned numPrimal = newScope->getNumResults();
Operation *bodyTerm = newBody.getTerminator();
// New terminator operands: the existing ones followed by the contributions.
SmallVector<Value> retVals(bodyTerm->getOperands().begin(),
bodyTerm->getOperands().end());
retVals.append(contributions.begin(), contributions.end());
// New result types: the existing ones followed by one shadow type per
// captured input.
SmallVector<Type> newResultTypes(newScope->getResultTypes().begin(),
newScope->getResultTypes().end());
for (Value v : capturedInputs)
newResultTypes.push_back(gutils->getShadowType(v.getType()));
// Step 8: create a scope op with the extended result list right before the
// moved clone, transfer the body into it, and replace its terminator.
OpBuilder scopeBuilder(newScope);
auto extScope = memref::AllocaScopeOp::create(scopeBuilder, op->getLoc(),
newResultTypes);
extScope.getBodyRegion().takeBody(newScope.getBodyRegion());
Block &extBody = extScope.getBodyRegion().front();
Operation *movedTerm = extBody.getTerminator();
OpBuilder termBuilder(movedTerm);
memref::AllocaScopeReturnOp::create(termBuilder, movedTerm->getLoc(),
retVals);
gutils->erase(movedTerm);
// Redirect users of the old clone's results and update the original-to-new
// value mapping, then drop the now-empty clone.
for (unsigned i = 0; i < numPrimal; ++i) {
newScope->getResult(i).replaceAllUsesWith(extScope.getResult(i));
gutils->originalToNewFn.map(scopeOp->getResult(i), extScope.getResult(i));
}
gutils->erase(newScope);
// Step 9: after the scope, restore each captured input's saved gradient and
// add the contribution returned from the scope.
builder.setInsertionPointAfter(extScope);
for (auto indexed : llvm::enumerate(capturedInputs)) {
Value v = indexed.value();
gutils->setDiffe(v, capturedPre[indexed.index()], builder);
gutils->addToDiffe(v, extScope.getResult(numPrimal + indexed.index()),
builder);
}
return success(valid);
}
/// Nothing is cached for the scope op itself.
SmallVector<Value> cacheValues(Operation *op,
MGradientUtilsReverse *gutils) const {
return SmallVector<Value>();
}
/// The scope op creates no shadow values.
void createShadowValues(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) const {}
};
/// External model implementing ClonableTypeInterface for `MemRefType`:
/// cloning allocates a fresh buffer and copies the contents into it.
class MemRefClonableTypeInterface
    : public ClonableTypeInterface::ExternalModel<MemRefClonableTypeInterface,
                                                  MemRefType> {
public:
  /// Returns a newly allocated memref of type `self` holding a copy of
  /// `value`. Each dynamic dimension of the allocation is sized with a
  /// `memref.dim` of `value`. All operations are created at `builder`'s
  /// insertion point with `value`'s location.
  mlir::Value cloneValue(mlir::Type self, OpBuilder &builder,
                         Value value) const {
    MemRefType MT = cast<MemRefType>(self);
    Location loc = value.getLoc();

    // One size operand per dynamic dimension, in dimension order.
    SmallVector<Value> dynamicSizes;
    for (auto [i, s] : llvm::enumerate(MT.getShape())) {
      if (s != ShapedType::kDynamic)
        continue;
      Value dim = arith::ConstantIndexOp::create(builder, loc, i);
      dynamicSizes.push_back(memref::DimOp::create(builder, loc, value, dim));
    }

    // Pass the concrete MemRefType so that the same typed AllocOp builder used
    // by MemRefAutoDiffTypeInterface::createNullValue is selected.
    // NOTE(review): with the erased `Type`, overload resolution presumably
    // falls back to the generic (result types, operands) builder, which would
    // not record the dynamic-size operand segment - confirm.
    auto clone = memref::AllocOp::create(builder, loc, MT, dynamicSizes);
    memref::CopyOp::create(builder, loc, value, clone);
    return clone;
  }

  /// Emits a `memref.dealloc` of `value` at `builder`'s insertion point.
  void freeClonedValue(mlir::Type self, OpBuilder &builder, Value value) const {
    memref::DeallocOp::create(builder, value.getLoc(), value);
  }
};
/// External model implementing AutoDiffTypeInterface for `MemRefType`.
///
/// Memrefs are treated as mutable: their shadow has the same type, they can be
/// zeroed in place, and value-style operations (add, conj, null attribute) are
/// not implemented.
class MemRefAutoDiffTypeInterface
    : public AutoDiffTypeInterface::ExternalModel<MemRefAutoDiffTypeInterface,
                                                  MemRefType> {
public:
  /// Not supported for memrefs; reaching this is a programming error.
  mlir::Attribute createNullAttr(mlir::Type self) const {
    llvm_unreachable("Cannot create null of memref (todo polygeist null)");
  }

  /// Allocates a memref of type `self` in which every dynamic dimension has
  /// extent 0. The contents of the allocation are not initialized here.
  mlir::Value createNullValue(mlir::Type self, OpBuilder &builder,
                              Location loc) const {
    MemRefType MT = cast<MemRefType>(self);
    unsigned numDynamicDims = MT.getNumDynamicDims();
    SmallVector<mlir::Value> dynamicSizes(numDynamicDims);
    for (unsigned i = 0; i < numDynamicDims; ++i) {
      // Use the static `Op::create` form, consistent with the rest of the
      // file.
      dynamicSizes[i] = mlir::arith::ConstantIndexOp::create(builder, loc, 0);
    }
    return mlir::memref::AllocOp::create(builder, loc, MT, dynamicSizes);
  }

  /// Not implemented for memrefs.
  Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
                    Value b) const {
    llvm_unreachable("TODO");
  }

  /// The shadow of a memref has the same type. Only `width == 1` is supported.
  Type getShadowType(Type self, unsigned width) const {
    assert(width == 1 && "unsupported width != 1");
    return self;
  }

  /// Not implemented for memrefs.
  Value createConjOp(Type self, OpBuilder &builder, Location loc,
                     Value a) const {
    llvm_unreachable("TODO");
  }

  /// Memrefs are mutable.
  bool isMutable(Type self) const { return true; }

  /// Stores the element type's null value into every element of `val`.
  ///
  /// Fails when the element type has no AutoDiffTypeInterface or is itself
  /// mutable. A rank-0 memref gets a single store; otherwise an `scf.parallel`
  /// loop nest over `[0, dim)` with step 1 in every dimension performs the
  /// stores.
  LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
                            Value val) const {
    auto MT = cast<MemRefType>(self);
    auto eltIface = dyn_cast<AutoDiffTypeInterface>(MT.getElementType());
    if (!eltIface || eltIface.isMutable())
      return failure();

    Value zero = eltIface.createNullValue(builder, loc);
    if (MT.getRank() == 0) {
      memref::StoreOp::create(builder, loc, zero, val, ValueRange{});
      return success();
    }

    Value c0 = arith::ConstantIndexOp::create(builder, loc, 0);
    Value c1 = arith::ConstantIndexOp::create(builder, loc, 1);
    SmallVector<Value> lbs(MT.getRank(), c0);
    SmallVector<Value> steps(MT.getRank(), c1);

    // Upper bounds: `memref.dim` for dynamic dimensions, a constant otherwise.
    SmallVector<Value> ubs;
    for (auto [i, d] : llvm::enumerate(MT.getShape())) {
      ubs.push_back(
          d == ShapedType::kDynamic
              ? memref::DimOp::create(builder, loc, val, i).getResult()
              : arith::ConstantIndexOp::create(builder, loc, d).getResult());
    }

    scf::ParallelOp::create(builder, loc, lbs, ubs, steps,
                            [&](OpBuilder &b, Location l, ValueRange ivs) {
                              memref::StoreOp::create(b, l, zero, val, ivs);
                            });
    return success();
  }

  /// A memref value is never reported as known-zero.
  bool isZero(Type self, Value val) const { return false; }

  /// A memref attribute is never reported as known-zero.
  bool isZeroAttr(Type self, Attribute val) const { return false; }
};
} // namespace
/// Registers, on `registry`, a dialect extension for the memref dialect that
/// attaches the automatic-differentiation external models defined in this
/// file (plus those provided by `registerInterfaces`).
void mlir::enzyme::registerMemRefDialectAutoDiffInterface(
    DialectRegistry &registry) {
  // Callback invoked with the context and the memref dialect; the dialect
  // pointer itself is not needed.
  auto attachModels = +[](MLIRContext *ctx, memref::MemRefDialect *) {
    registerInterfaces(ctx);

    // Type interfaces on MemRefType.
    MemRefType::attachInterface<MemRefAutoDiffTypeInterface>(*ctx);
    MemRefType::attachInterface<MemRefClonableTypeInterface>(*ctx);

    // Reverse-mode op interfaces.
    memref::LoadOp::attachInterface<LoadOpInterfaceReverse>(*ctx);
    memref::StoreOp::attachInterface<StoreOpInterfaceReverse>(*ctx);
    memref::SubViewOp::attachInterface<SubViewOpInterfaceReverse>(*ctx);
    memref::AllocaScopeOp::attachInterface<AllocaScopeOpInterfaceReverse>(
        *ctx);
  };
  registry.addExtension(attachModels);
}