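Support width > 1 in MemRef getShadowType

For width > 1, the shadow type is the primal memref type with an extra leading dimension of size `width`.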
diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp
index 3fae773..a1bc038 100644
--- a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp
+++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp
@@ -214,8 +214,22 @@
   }
 
   Type getShadowType(Type self, unsigned width) const {
-    assert(width == 1 && "unsupported width != 1");
-    return self;
+    if (width == 1)
+      return self;
+
+    // Batched shadow: prepend a leading dimension of size `width` to the
+    // primal shape, e.g. memref<4xf32> with width 3 -> memref<3x4xf32>.
+    auto MRT = llvm::cast<MemRefType>(self);
+    // Do not use MRT.clone(): it keeps the primal's layout attribute, which
+    // is defined for the primal's rank and so mismatches the higher-rank
+    // shadow. Only identity layouts are supported here; the shadow is built
+    // with the default (identity) layout and the primal's memory space.
+    assert(MRT.getLayout().isIdentity() &&
+           "unsupported non-identity memref layout for width != 1");
+    SmallVector<int64_t> out_shape = {static_cast<int64_t>(width)};
+    out_shape.append(MRT.getShape().begin(), MRT.getShape().end());
+    return MemRefType::get(out_shape, MRT.getElementType(),
+                           MemRefLayoutAttrInterface(), MRT.getMemorySpace());
   }
 
   Value createConjOp(Type self, OpBuilder &builder, Location loc,