| //===- MemRefAutoDiffOpInterfaceImpl.cpp - Interface external model -------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file contains the external model implementation of the automatic |
| // differentiation op interfaces for the upstream MLIR memref dialect. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "Implementations/CoreDialectsAutoDiffImplementations.h" |
| #include "Interfaces/AutoDiffOpInterface.h" |
| #include "Interfaces/AutoDiffTypeInterface.h" |
| #include "Interfaces/GradientUtils.h" |
| #include "Interfaces/GradientUtilsReverse.h" |
| |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/IR/DialectRegistry.h" |
| #include "mlir/Support/LogicalResult.h" |
| |
| using namespace mlir; |
| using namespace mlir::enzyme; |
| |
| namespace { |
| #include "Implementations/MemRefDerivatives.inc" |
| |
| struct LoadOpInterfaceReverse |
| : public ReverseAutoDiffOpInterface::ExternalModel<LoadOpInterfaceReverse, |
| memref::LoadOp> { |
| LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, |
| MGradientUtilsReverse *gutils, |
| SmallVector<Value> caches) const { |
| auto loadOp = cast<memref::LoadOp>(op); |
| Value memref = loadOp.getMemref(); |
| |
| if (auto iface = dyn_cast<AutoDiffTypeInterface>(loadOp.getType())) { |
| if (!gutils->isConstantValue(loadOp) && |
| !gutils->isConstantValue(memref)) { |
| Value gradient = gutils->diffe(loadOp, builder); |
| Value memrefGradient = gutils->popCache(caches.front(), builder); |
| |
| SmallVector<Value> retrievedArguments; |
| for (Value cache : ValueRange(caches).drop_front(1)) { |
| Value retrievedValue = gutils->popCache(cache, builder); |
| retrievedArguments.push_back(retrievedValue); |
| } |
| |
| if (!gutils->AtomicAdd) { |
| Value loadedGradient = |
| memref::LoadOp::create(builder, loadOp.getLoc(), memrefGradient, |
| ArrayRef<Value>(retrievedArguments)); |
| Value addedGradient = iface.createAddOp(builder, loadOp.getLoc(), |
| loadedGradient, gradient); |
| memref::StoreOp::create(builder, loadOp.getLoc(), addedGradient, |
| memrefGradient, |
| ArrayRef<Value>(retrievedArguments)); |
| } else { |
| memref::AtomicRMWOp::create( |
| builder, loadOp.getLoc(), arith::AtomicRMWKind::addf, gradient, |
| memrefGradient, ArrayRef<Value>(retrievedArguments)); |
| } |
| } |
| } |
| return success(); |
| } |
| |
| SmallVector<Value> cacheValues(Operation *op, |
| MGradientUtilsReverse *gutils) const { |
| auto loadOp = cast<memref::LoadOp>(op); |
| Value memref = loadOp.getMemref(); |
| ValueRange indices = loadOp.getIndices(); |
| if (auto iface = dyn_cast<AutoDiffTypeInterface>(loadOp.getType())) { |
| if (!gutils->isConstantValue(loadOp) && |
| !gutils->isConstantValue(memref)) { |
| OpBuilder cacheBuilder(gutils->getNewFromOriginal(op)); |
| SmallVector<Value> caches; |
| caches.push_back(gutils->initAndPushCache( |
| gutils->invertPointerM(memref, cacheBuilder), cacheBuilder)); |
| for (Value v : indices) { |
| caches.push_back(gutils->initAndPushCache( |
| gutils->getNewFromOriginal(v), cacheBuilder)); |
| } |
| return caches; |
| } |
| } |
| return SmallVector<Value>(); |
| } |
| |
| void createShadowValues(Operation *op, OpBuilder &builder, |
| MGradientUtilsReverse *gutils) const { |
| // auto loadOp = cast<memref::LoadOp>(op); |
| // Value memref = loadOp.getMemref(); |
| // Value shadow = gutils->getShadowValue(memref); |
| // Do nothing yet. In the future support memref<memref<...>> |
| } |
| }; |
| |
| struct StoreOpInterfaceReverse |
| : public ReverseAutoDiffOpInterface::ExternalModel<StoreOpInterfaceReverse, |
| memref::StoreOp> { |
| LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, |
| MGradientUtilsReverse *gutils, |
| SmallVector<Value> caches) const { |
| auto storeOp = cast<memref::StoreOp>(op); |
| Value val = storeOp.getValue(); |
| Value memref = storeOp.getMemref(); |
| // ValueRange indices = storeOp.getIndices(); |
| |
| auto iface = cast<AutoDiffTypeInterface>(val.getType()); |
| |
| if (!gutils->isConstantValue(memref)) { |
| Value memrefGradient = gutils->popCache(caches.front(), builder); |
| |
| SmallVector<Value> retrievedArguments; |
| for (Value cache : ValueRange(caches).drop_front(1)) { |
| Value retrievedValue = gutils->popCache(cache, builder); |
| retrievedArguments.push_back(retrievedValue); |
| } |
| |
| if (!iface.isMutable()) { |
| if (!gutils->isConstantValue(val)) { |
| Value loadedGradient = |
| memref::LoadOp::create(builder, storeOp.getLoc(), memrefGradient, |
| ArrayRef<Value>(retrievedArguments)); |
| gutils->addToDiffe(val, loadedGradient, builder); |
| } |
| |
| auto zero = |
| cast<AutoDiffTypeInterface>(gutils->getShadowType(val.getType())) |
| .createNullValue(builder, op->getLoc()); |
| |
| memref::StoreOp::create(builder, storeOp.getLoc(), zero, memrefGradient, |
| ArrayRef<Value>(retrievedArguments)); |
| } |
| } |
| return success(); |
| } |
| |
| SmallVector<Value> cacheValues(Operation *op, |
| MGradientUtilsReverse *gutils) const { |
| auto storeOp = cast<memref::StoreOp>(op); |
| Value memref = storeOp.getMemref(); |
| ValueRange indices = storeOp.getIndices(); |
| Value val = storeOp.getValue(); |
| if (auto iface = dyn_cast<AutoDiffTypeInterface>(val.getType())) { |
| if (!gutils->isConstantValue(memref)) { |
| OpBuilder cacheBuilder(gutils->getNewFromOriginal(op)); |
| SmallVector<Value> caches; |
| caches.push_back(gutils->initAndPushCache( |
| gutils->invertPointerM(memref, cacheBuilder), cacheBuilder)); |
| for (Value v : indices) { |
| caches.push_back(gutils->initAndPushCache( |
| gutils->getNewFromOriginal(v), cacheBuilder)); |
| } |
| return caches; |
| } |
| } |
| return SmallVector<Value>(); |
| } |
| |
| void createShadowValues(Operation *op, OpBuilder &builder, |
| MGradientUtilsReverse *gutils) const { |
| // auto storeOp = cast<memref::StoreOp>(op); |
| // Value memref = storeOp.getMemref(); |
| // Value shadow = gutils->getShadowValue(memref); |
| // Do nothing yet. In the future support memref<memref<...>> |
| } |
| }; |
| |
| struct SubViewOpInterfaceReverse |
| : public ReverseAutoDiffOpInterface::ExternalModel< |
| SubViewOpInterfaceReverse, memref::SubViewOp> { |
| LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, |
| MGradientUtilsReverse *gutils, |
| SmallVector<Value> caches) const { |
| return success(); |
| } |
| |
| SmallVector<Value> cacheValues(Operation *op, |
| MGradientUtilsReverse *gutils) const { |
| return SmallVector<Value>(); |
| } |
| |
| void createShadowValues(Operation *op, OpBuilder &builder, |
| MGradientUtilsReverse *gutils) const { |
| auto subviewOp = cast<memref::SubViewOp>(op); |
| auto newSubviewOp = cast<memref::SubViewOp>(gutils->getNewFromOriginal(op)); |
| if (!gutils->isConstantValue(subviewOp.getSource())) { |
| Value shadow = memref::SubViewOp::create( |
| builder, op->getLoc(), newSubviewOp.getType(), |
| gutils->invertPointerM(subviewOp.getSource(), builder), |
| newSubviewOp.getMixedOffsets(), newSubviewOp.getMixedSizes(), |
| newSubviewOp.getMixedStrides()); |
| gutils->setInvertedPointer(subviewOp, shadow); |
| } |
| } |
| }; |
| |
/// Reverse-mode model for `memref.alloca_scope`.
///
/// Nothing is cached. The adjoint moves the cloned scope (obtained from
/// getNewFromOriginal) to the current reverse insertion point. It then emits
/// the adjoints of the original body operations inside that scope, in reverse
/// order, just before its terminator.
///
/// Some active, non-mutable values are defined outside the scope but used
/// inside it ("captured inputs"). Their gradient contributions are returned
/// from the scope as extra results and merged back after it.
struct AllocaScopeOpInterfaceReverse
    : public ReverseAutoDiffOpInterface::ExternalModel<
          AllocaScopeOpInterfaceReverse, memref::AllocaScopeOp> {
  LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
                                         MGradientUtilsReverse *gutils,
                                         SmallVector<Value> caches) const {
    auto scopeOp = cast<memref::AllocaScopeOp>(op);

    // Read (and then clear) the incoming adjoint of every active scope
    // result. incomingGradients only holds entries for active results, in
    // result order; resultsActive records which results those are.
    SmallVector<bool> resultsActive(scopeOp->getNumResults(), false);
    SmallVector<Value> incomingGradients;
    incomingGradients.reserve(scopeOp->getNumResults());
    for (OpResult result : scopeOp->getResults()) {
      bool active = !gutils->isConstantValue(result);
      resultsActive[result.getResultNumber()] = active;
      if (active) {
        incomingGradients.push_back(gutils->diffe(result, builder));
        gutils->zeroDiffe(result, builder);
      }
    }

    // Collect captured inputs: operands of any op nested in the scope that
    // are defined outside the scope region, are active, and have a
    // non-mutable AutoDiffTypeInterface type. Values without a parent region
    // are skipped. Duplicates are dropped; first-use order is kept.
    Region &scopeRegion = scopeOp.getBodyRegion();
    SmallVector<Value> capturedInputs;
    scopeRegion.walk([&](Operation *inner) {
      for (Value operand : inner->getOperands()) {
        Region *defRegion = operand.getParentRegion();
        if (!defRegion || scopeRegion.isAncestor(defRegion))
          continue;
        if (gutils->isConstantValue(operand))
          continue;
        auto iface = dyn_cast<AutoDiffTypeInterface>(operand.getType());
        if (!iface || iface.isMutable())
          continue;
        if (!llvm::is_contained(capturedInputs, operand))
          capturedInputs.push_back(operand);
      }
    });

    // Save each captured input's current adjoint and clear it. Inside the
    // scope its adjoint then accumulates only this scope's contribution,
    // which is read back out near the end of this function.
    SmallVector<Value> capturedPre;
    capturedPre.reserve(capturedInputs.size());
    for (Value v : capturedInputs) {
      capturedPre.push_back(gutils->diffe(v, builder));
      gutils->zeroDiffe(v, builder);
    }

    // Move the cloned (new-function) scope to the reverse insertion point and
    // emit reverse code inside its body, before its terminator.
    auto newScope = cast<memref::AllocaScopeOp>(gutils->getNewFromOriginal(op));
    newScope->moveBefore(builder.getInsertionBlock(),
                         builder.getInsertionPoint());

    Block &newBody = newScope.getBodyRegion().front();
    OpBuilder bodyBuilder(newBody.getTerminator());

    Block &oldBody = scopeOp.getBodyRegion().front();
    bool valid = true;

    // Clear the adjoints of all active, non-mutable values produced by
    // top-level ops of the original body before any reverse code runs.
    // NOTE(review): presumably so a stale adjoint (e.g. from a previous
    // execution of this code) is not accumulated into -- confirm.
    for (Operation &innerOp : oldBody.getOperations()) {
      for (Value result : innerOp.getResults()) {
        if (!gutils->isConstantValue(result)) {
          auto iface = dyn_cast<AutoDiffTypeInterface>(result.getType());
          if (iface && !iface.isMutable())
            gutils->zeroDiffe(result, bodyBuilder);
        }
      }
    }

    // Seed the adjoints of the terminator operands with the incoming scope
    // result adjoints. Operand i of alloca_scope.return feeds result i.
    // incomingIdx only advances for active results, matching the order in
    // which incomingGradients was filled.
    Operation *term = oldBody.getTerminator();
    unsigned incomingIdx = 0;
    for (auto indexedOperand : llvm::enumerate(term->getOperands())) {
      if (!resultsActive[indexedOperand.index()])
        continue;

      Value operand = indexedOperand.value();
      if (!gutils->isConstantValue(operand))
        gutils->addToDiffe(operand, incomingGradients[incomingIdx],
                           bodyBuilder);
      incomingIdx++;
    }

    // Visit the original body in reverse, skipping the terminator (the last
    // op). Any failing child makes the whole adjoint report failure.
    auto first = oldBody.rbegin();
    first++;

    for (auto it = first; it != oldBody.rend(); it++) {
      valid &= gutils->Logic.visitChild(&*it, bodyBuilder, gutils).succeeded();
    }

    // Without captured inputs the moved scope is kept as is.
    if (capturedInputs.empty())
      return success(valid);

    // Adjoints accumulated inside the scope for the captured inputs.
    SmallVector<Value> contributions;
    contributions.reserve(capturedInputs.size());
    for (Value v : capturedInputs)
      contributions.push_back(gutils->diffe(v, bodyBuilder));

    // New terminator operands: the original returned values followed by the
    // captured-input contributions.
    unsigned numPrimal = newScope->getNumResults();
    Operation *bodyTerm = newBody.getTerminator();
    SmallVector<Value> retVals(bodyTerm->getOperands().begin(),
                               bodyTerm->getOperands().end());
    retVals.append(contributions.begin(), contributions.end());

    SmallVector<Type> newResultTypes(newScope->getResultTypes().begin(),
                                     newScope->getResultTypes().end());
    for (Value v : capturedInputs)
      newResultTypes.push_back(gutils->getShadowType(v.getType()));

    // An operation's result list is fixed at creation, so build a new scope
    // with the extended result types right before newScope and move the
    // whole body region into it.
    OpBuilder scopeBuilder(newScope);
    auto extScope = memref::AllocaScopeOp::create(scopeBuilder, op->getLoc(),
                                                  newResultTypes);
    extScope.getBodyRegion().takeBody(newScope.getBodyRegion());

    // Replace the moved terminator with one that also returns the
    // contributions.
    Block &extBody = extScope.getBodyRegion().front();
    Operation *movedTerm = extBody.getTerminator();
    OpBuilder termBuilder(movedTerm);
    memref::AllocaScopeReturnOp::create(termBuilder, movedTerm->getLoc(),
                                        retVals);
    gutils->erase(movedTerm);

    // Redirect uses and the original->new mapping of the primal results to
    // the extended scope, then drop the old (now bodiless) scope.
    for (unsigned i = 0; i < numPrimal; ++i) {
      newScope->getResult(i).replaceAllUsesWith(extScope.getResult(i));
      gutils->originalToNewFn.map(scopeOp->getResult(i), extScope.getResult(i));
    }
    gutils->erase(newScope);

    // After the scope: adjoint(v) = saved adjoint + contribution returned by
    // the scope.
    builder.setInsertionPointAfter(extScope);
    for (auto indexed : llvm::enumerate(capturedInputs)) {
      Value v = indexed.value();
      gutils->setDiffe(v, capturedPre[indexed.index()], builder);
      gutils->addToDiffe(v, extScope.getResult(numPrimal + indexed.index()),
                         builder);
    }

    return success(valid);
  }

  // Nothing is cached for alloca_scope; the body is handled in place above.
  SmallVector<Value> cacheValues(Operation *op,
                                 MGradientUtilsReverse *gutils) const {
    return SmallVector<Value>();
  }

  // alloca_scope creates no shadow values of its own.
  void createShadowValues(Operation *op, OpBuilder &builder,
                          MGradientUtilsReverse *gutils) const {}
};
| |
| class MemRefClonableTypeInterface |
| : public ClonableTypeInterface::ExternalModel<MemRefClonableTypeInterface, |
| MemRefType> { |
| |
| public: |
| mlir::Value cloneValue(mlir::Type self, OpBuilder &builder, |
| Value value) const { |
| MemRefType MT = cast<MemRefType>(self); |
| SmallVector<Value> dynamicSizes; |
| |
| for (auto [i, s] : llvm::enumerate(MT.getShape())) { |
| if (s == ShapedType::kDynamic) { |
| Value dim = arith::ConstantIndexOp::create(builder, value.getLoc(), i); |
| dynamicSizes.push_back( |
| memref::DimOp::create(builder, value.getLoc(), value, dim)); |
| } |
| } |
| |
| auto clone = |
| memref::AllocOp::create(builder, value.getLoc(), self, dynamicSizes); |
| memref::CopyOp::create(builder, value.getLoc(), value, clone); |
| |
| return clone; |
| } |
| |
| void freeClonedValue(mlir::Type self, OpBuilder &builder, Value value) const { |
| memref::DeallocOp::create(builder, value.getLoc(), value); |
| }; |
| }; |
| |
| class MemRefAutoDiffTypeInterface |
| : public AutoDiffTypeInterface::ExternalModel<MemRefAutoDiffTypeInterface, |
| MemRefType> { |
| public: |
| mlir::Attribute createNullAttr(mlir::Type self) const { |
| llvm_unreachable("Cannot create null of memref (todo polygeist null)"); |
| } |
| mlir::Value createNullValue(mlir::Type self, OpBuilder &builder, |
| Location loc) const { |
| // Create a memref of the given type with the required number of |
| // dynamic dimensions, all set to 0 |
| MemRefType MT = cast<MemRefType>(self); |
| unsigned numDynamicDims = MT.getNumDynamicDims(); |
| SmallVector<mlir::Value> dynamicSizes(numDynamicDims); |
| for (unsigned i = 0; i < numDynamicDims; ++i) { |
| dynamicSizes[i] = builder.create<mlir::arith::ConstantIndexOp>(loc, 0); |
| } |
| return mlir::memref::AllocOp::create(builder, loc, MT, dynamicSizes); |
| } |
| |
| Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a, |
| Value b) const { |
| llvm_unreachable("TODO"); |
| } |
| |
| Type getShadowType(Type self, unsigned width) const { |
| assert(width == 1 && "unsupported width != 1"); |
| return self; |
| } |
| |
| Value createConjOp(Type self, OpBuilder &builder, Location loc, |
| Value a) const { |
| llvm_unreachable("TODO"); |
| } |
| |
| bool isMutable(Type self) const { return true; } |
| |
| LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, |
| Value val) const { |
| auto MT = cast<MemRefType>(self); |
| auto eltIface = dyn_cast<AutoDiffTypeInterface>(MT.getElementType()); |
| if (!eltIface || eltIface.isMutable()) |
| return failure(); |
| Value zero = eltIface.createNullValue(builder, loc); |
| |
| if (MT.getRank() == 0) { |
| memref::StoreOp::create(builder, loc, zero, val, ValueRange{}); |
| return success(); |
| } |
| |
| Value c0 = arith::ConstantIndexOp::create(builder, loc, 0); |
| Value c1 = arith::ConstantIndexOp::create(builder, loc, 1); |
| |
| SmallVector<Value> lbs(MT.getRank(), c0); |
| SmallVector<Value> steps(MT.getRank(), c1); |
| SmallVector<Value> ubs; |
| for (auto [i, d] : llvm::enumerate(MT.getShape())) { |
| ubs.push_back( |
| d == ShapedType::kDynamic |
| ? memref::DimOp::create(builder, loc, val, i).getResult() |
| : arith::ConstantIndexOp::create(builder, loc, d).getResult()); |
| } |
| |
| scf::ParallelOp::create(builder, loc, lbs, ubs, steps, |
| [&](OpBuilder &b, Location l, ValueRange ivs) { |
| memref::StoreOp::create(b, l, zero, val, ivs); |
| }); |
| return success(); |
| } |
| |
| bool isZero(Type self, Value val) const { return false; } |
| bool isZeroAttr(Type self, Attribute val) const { return false; } |
| }; |
| } // namespace |
| |
| void mlir::enzyme::registerMemRefDialectAutoDiffInterface( |
| DialectRegistry ®istry) { |
| registry.addExtension(+[](MLIRContext *context, memref::MemRefDialect *) { |
| registerInterfaces(context); |
| MemRefType::attachInterface<MemRefAutoDiffTypeInterface>(*context); |
| MemRefType::attachInterface<MemRefClonableTypeInterface>(*context); |
| |
| memref::LoadOp::attachInterface<LoadOpInterfaceReverse>(*context); |
| memref::StoreOp::attachInterface<StoreOpInterfaceReverse>(*context); |
| memref::SubViewOp::attachInterface<SubViewOpInterfaceReverse>(*context); |
| memref::AllocaScopeOp::attachInterface<AllocaScopeOpInterfaceReverse>( |
| *context); |
| }); |
| } |