Skip to content

Commit 22b8e85

Browse files
authored
Merge pull request #16 from EnzymeAD/mjp/unify-apis
Unify MLIR API with LLVM API
2 parents 46b6f9d + 752a3a7 commit 22b8e85

File tree

10 files changed

+53
-29
lines changed

10 files changed

+53
-29
lines changed

enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def ForwardDiffOp : Enzyme_Op<"fwddiff",
5555
}];
5656
}
5757

58-
def DiffOp : Enzyme_Op<"diff",
58+
def AutoDiffOp : Enzyme_Op<"autodiff",
5959
[DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
6060
let summary = "Perform reverse mode AD on a funcop";
6161
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs, ActivityArrayAttr:$activity);

enzyme/Enzyme/MLIR/Dialect/Ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ ForwardDiffOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
5353
return success();
5454
}
5555

56-
LogicalResult DiffOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
56+
LogicalResult AutoDiffOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
5757
// TODO: Verify that the result type is same as the type of the referenced
5858
// func.func op.
5959
auto global =

enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
#include "Interfaces/GradientUtils.h"
1818
#include "Interfaces/GradientUtilsReverse.h"
1919

20+
// TODO: We need a way to zero out a memref (which linalg.fill does), but
21+
// ideally we wouldn't depend on the linalg dialect.
22+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
2023
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2124
#include "mlir/IR/DialectRegistry.h"
2225
#include "mlir/Support/LogicalResult.h"
@@ -197,15 +200,10 @@ struct AllocOpInterfaceReverse
197200
memref::AllocOp> {
198201
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
199202
MGradientUtilsReverse *gutils,
200-
SmallVector<Value> caches) const {
201-
auto allocOp = cast<memref::AllocOp>(op);
202-
Value memref = allocOp.getMemref();
203-
}
203+
SmallVector<Value> caches) const {}
204204

205205
SmallVector<Value> cacheValues(Operation *op,
206206
MGradientUtilsReverse *gutils) const {
207-
auto allocOp = cast<memref::AllocOp>(op);
208-
209207
return SmallVector<Value>();
210208
}
211209

@@ -216,6 +214,15 @@ struct AllocOpInterfaceReverse
216214

217215
Value shadow = builder.create<memref::AllocOp>(
218216
op->getLoc(), newAllocOp.getType(), newAllocOp.getDynamicSizes());
217+
// Fill with zeros
218+
if (auto iface = dyn_cast<AutoDiffTypeInterface>(
219+
allocOp.getType().getElementType())) {
220+
Value zero = iface.createNullValue(builder, op->getLoc());
221+
builder.create<linalg::FillOp>(op->getLoc(), zero, shadow);
222+
} else {
223+
op->emitWarning() << "memref.alloc element type does not implement "
224+
"AutoDiffTypeInterface";
225+
}
219226
gutils->mapShadowValue(allocOp, shadow, builder);
220227
}
221228
};

enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,10 @@ FunctionOpInterface CloneFunctionWithReturns(
242242
mlir::Value oval = F.getFunctionBody().front().getArgument(i);
243243
if (constant_args[i] == DIFFE_TYPE_MLIR::CONSTANT)
244244
constants.insert(oval);
245-
else
245+
else if (constant_args[i] == DIFFE_TYPE_MLIR::OUT_DIFF)
246246
nonconstants.insert(oval);
247-
if (constant_args[i] == DIFFE_TYPE_MLIR::DUP_ARG ||
248-
constant_args[i] == DIFFE_TYPE_MLIR::DUP_NONEED) {
247+
else if (constant_args[i] == DIFFE_TYPE_MLIR::DUP_ARG ||
248+
constant_args[i] == DIFFE_TYPE_MLIR::DUP_NONEED) {
249249
mlir::Value val = blk.getArgument(i);
250250
mlir::Value dval;
251251
if (i == constant_args.size() - 1)

enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,16 @@ void MEnzymeLogic::handlePredecessors(Block *oBB, Block *newBB,
199199
OpBuilder revBuilder(reverseBB, reverseBB->end());
200200
if (oBB->hasNoPredecessors()) {
201201
SmallVector<mlir::Value> retargs;
202-
for (Value attribute : gutils->oldFunc.getFunctionBody().getArguments()) {
203-
Value attributeGradient = gutils->invertPointerM(attribute, revBuilder);
204-
retargs.push_back(attributeGradient);
202+
assert(gutils->ArgDiffeTypes.size() == gutils->oldFunc.getNumArguments() &&
203+
"Mismatch of activity array size vs # original function args");
204+
for (const auto &[diffeType, oldArg] :
205+
llvm::zip(gutils->ArgDiffeTypes,
206+
gutils->oldFunc.getFunctionBody().getArguments())) {
207+
if (diffeType == DIFFE_TYPE_MLIR::OUT_DIFF) {
208+
retargs.push_back(gutils->invertPointerM(oldArg, revBuilder));
209+
}
205210
}
206211
buildReturnOp(revBuilder, oBB->rbegin()->getLoc(), retargs);
207-
// revBuilder.create<func::ReturnOp>(oBB->rbegin()->getLoc(), retargs);
208212
} else {
209213
SmallVector<Block *> blocks;
210214
SmallVector<APInt> indices;
@@ -325,7 +329,7 @@ void MEnzymeLogic::initializeShadowValues(
325329
void MEnzymeLogic::differentiate(MGradientUtilsReverse *gutils,
326330
Region &oldRegion, Region &newRegion,
327331
bool parentRegion,
328-
buildReturnFunction buildFuncRetrunOp) {
332+
buildReturnFunction buildFuncReturnOp) {
329333
gutils->createReverseModeBlocks(oldRegion, newRegion, parentRegion);
330334

331335
SmallVector<mlir::Block *> dominatorToposortBlocks =
@@ -341,7 +345,7 @@ void MEnzymeLogic::differentiate(MGradientUtilsReverse *gutils,
341345
mapInvertArguments(oBB, reverseBB, gutils);
342346
handleReturns(oBB, newBB, reverseBB, gutils, parentRegion);
343347
visitChildren(oBB, reverseBB, gutils);
344-
handlePredecessors(oBB, newBB, reverseBB, gutils, buildFuncRetrunOp);
348+
handlePredecessors(oBB, newBB, reverseBB, gutils, buildFuncReturnOp);
345349
}
346350
}
347351

@@ -358,21 +362,21 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
358362
llvm_unreachable("Differentiating empty function");
359363
}
360364

361-
ReturnTypeMLIR returnValue = ReturnTypeMLIR::Tape;
365+
ReturnTypeMLIR returnValue = ReturnTypeMLIR::Args;
362366
MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone(
363367
*this, mode, width, fn, TA, type_args, retType, /*diffeReturnArg*/ true,
364368
constants, returnValue, addedType, symbolTable);
365369

366370
Region &oldRegion = gutils->oldFunc.getFunctionBody();
367371
Region &newRegion = gutils->newFunc.getFunctionBody();
368372

369-
buildReturnFunction buildFuncRetrunOp = [](OpBuilder &builder, Location loc,
373+
buildReturnFunction buildFuncReturnOp = [](OpBuilder &builder, Location loc,
370374
SmallVector<Value> retargs) {
371375
builder.create<func::ReturnOp>(loc, retargs);
372376
return;
373377
};
374378

375-
differentiate(gutils, oldRegion, newRegion, true, buildFuncRetrunOp);
379+
differentiate(gutils, oldRegion, newRegion, true, buildFuncReturnOp);
376380

377381
auto nf = gutils->newFunc;
378382

enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse(
4343
originalToNewFn(originalToNewFn_),
4444
originalToNewFnOps(originalToNewFnOps_), symbolTable(symbolTable_) {
4545

46-
initInitializationBlock(invertedPointers_);
46+
initInitializationBlock(invertedPointers_, activevals_);
4747
}
4848

4949
// for(auto x : v.getUsers()){x->dump();} DEBUG
@@ -277,13 +277,24 @@ bool mlir::enzyme::MGradientUtilsReverse::hasInvertPointer(mlir::Value v) {
277277
}
278278

279279
void MGradientUtilsReverse::initInitializationBlock(
280-
BlockAndValueMapping invertedPointers_) {
280+
BlockAndValueMapping invertedPointers_,
281+
const SmallPtrSetImpl<Value> &activevals_) {
281282
initializationBlock = &*(this->newFunc.getFunctionBody().begin());
282283

283284
OpBuilder initializationBuilder(
284285
&*(this->newFunc.getFunctionBody().begin()),
285286
this->newFunc.getFunctionBody().begin()->begin());
286287

288+
for (Value activeval : activevals_) {
289+
if (auto iface = dyn_cast<AutoDiffTypeInterface>(activeval.getType())) {
290+
Value zero =
291+
iface.createNullValue(initializationBuilder, activeval.getLoc());
292+
mapInvertPointer(activeval, zero, initializationBuilder);
293+
} else {
294+
llvm_unreachable(
295+
"Type does not have an associated AutoDiffTypeInterface");
296+
}
297+
}
287298
for (auto const &x : invertedPointers_.getValueMap()) {
288299
if (auto iface = dyn_cast<AutoDiffTypeInterface>(x.first.getType())) {
289300
if (iface.requiresShadow()) {

enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ class MGradientUtilsReverse {
9595

9696
bool requiresShadow(Type t);
9797

98-
void initInitializationBlock(BlockAndValueMapping invertedPointers_);
98+
void initInitializationBlock(BlockAndValueMapping invertedPointers_,
99+
const SmallPtrSetImpl<mlir::Value> &activevals_);
99100

100101
bool onlyUsedInParentBlock(Value v);
101102

enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
178178

179179
{
180180
SmallVector<Operation *> toLower;
181-
op->walk([&](enzyme::DiffOp dop) {
181+
op->walk([&](enzyme::AutoDiffOp dop) {
182182
auto *symbolOp =
183183
symbolTable.lookupNearestSymbolFrom(dop, dop.getFnAttr());
184184
auto callableOp = cast<FunctionOpInterface>(symbolOp);
@@ -188,7 +188,7 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
188188
});
189189

190190
for (auto T : toLower) {
191-
if (auto F = dyn_cast<enzyme::DiffOp>(T)) {
191+
if (auto F = dyn_cast<enzyme::AutoDiffOp>(T)) {
192192
HandleAutoDiffReverse(symbolTable, F);
193193
} else {
194194
llvm_unreachable("Illegal type");

enzyme/Enzyme/MLIR/Passes/LowerToLLVMEnzyme.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,11 @@ void convertMemRefArgument(Location loc, Value primal,
100100
operands.push_back(memrefPrimal.stride(b, loc, pos));
101101
}
102102

103-
struct DiffOpLowering : public OpConversionPattern<enzyme::DiffOp> {
104-
using OpConversionPattern<enzyme::DiffOp>::OpConversionPattern;
103+
struct DiffOpLowering : public OpConversionPattern<enzyme::AutoDiffOp> {
104+
using OpConversionPattern<enzyme::AutoDiffOp>::OpConversionPattern;
105105

106106
LogicalResult
107-
matchAndRewrite(enzyme::DiffOp op, OpAdaptor adaptor,
107+
matchAndRewrite(enzyme::AutoDiffOp op, OpAdaptor adaptor,
108108
ConversionPatternRewriter &rewriter) const override {
109109
auto moduleOp = op->getParentOfType<ModuleOp>();
110110
Location loc = op.getLoc();
@@ -263,7 +263,7 @@ struct LowerToLLVMEnzymePass
263263
populateFuncToLLVMConversionPatterns(typeConverter, patterns);
264264

265265
ConversionTarget target(*context);
266-
target.addIllegalOp<enzyme::DiffOp>();
266+
target.addIllegalOp<enzyme::AutoDiffOp>();
267267
target.addLegalDialect<LLVM::LLVMDialect>();
268268

269269
if (failed(applyPartialConversion(getOperation(), target,

enzyme/Enzyme/MLIR/enzymemlir-opt.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ int main(int argc, char **argv) {
5858
registry.insert<mlir::NVVM::NVVMDialect>();
5959
registry.insert<mlir::omp::OpenMPDialect>();
6060
registry.insert<mlir::math::MathDialect>();
61+
registry.insert<mlir::linalg::LinalgDialect>();
6162
registry.insert<DLTIDialect>();
6263

6364
registry.insert<mlir::enzyme::EnzymeDialect>();

0 commit comments

Comments
 (0)