Skip to content

Commit

Permalink
further calling convention fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 12, 2019
1 parent fbc6d9a commit 4e5c31f
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 38 deletions.
53 changes: 39 additions & 14 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ llvm::cl::opt<bool> enzyme_print("enzyme_print", cl::init(false), cl::Hidden,

//! return structtype if recursive function
std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResults &AA, const std::set<unsigned>& constant_args, TargetLibraryInfo &TLI, bool differentialReturn, bool returnUsed) {
static std::map<std::tuple<Function*,std::set<unsigned>, bool/*differentialReturn*/>, std::pair<Function*,StructType*>> cachedfunctions;
static std::map<std::tuple<Function*,std::set<unsigned>, bool/*differentialReturn*/>, bool> cachedfinished;
auto tup = std::make_tuple(todiff, std::set<unsigned>(constant_args.begin(), constant_args.end()), differentialReturn);
static std::map<std::tuple<Function*,std::set<unsigned>, bool/*differentialReturn*/, bool/*returnUsed*/>, std::pair<Function*,StructType*>> cachedfunctions;
static std::map<std::tuple<Function*,std::set<unsigned>, bool/*differentialReturn*/, bool/*returnUsed*/>, bool> cachedfinished;
auto tup = std::make_tuple(todiff, std::set<unsigned>(constant_args.begin(), constant_args.end()), differentialReturn, returnUsed);
if (cachedfunctions.find(tup) != cachedfunctions.end()) {
return cachedfunctions[tup];
}
Expand Down Expand Up @@ -110,13 +110,19 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
gutils->forceAugmentedReturns();

//! Explicitly handle all returns first to ensure that all instructions know whether or not they are used
SmallPtrSet<Instruction*, 4> returnuses;

for(BasicBlock* BB: gutils->originalBlocks) {
if(auto ri = dyn_cast<ReturnInst>(BB->getTerminator())) {
auto oldval = ri->getReturnValue();
Value* rt = UndefValue::get(gutils->newFunc->getReturnType());
IRBuilder <>ib(ri);
if (oldval && returnUsed)
if (oldval && returnUsed) {
rt = ib.CreateInsertValue(rt, oldval, {1});
if (Instruction* inst = dyn_cast<Instruction>(rt)) {
returnuses.insert(inst);
}
}
ib.CreateRet(rt);
gutils->erase(ri);
/*
Expand Down Expand Up @@ -333,6 +339,16 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul

bool subretused = op->getNumUses() != 0;
bool subdifferentialreturn = !gutils->isConstantValue(op) && subretused;

//! We only need to cache something if it is used in a non return setting (since the backard pass doesnt need to use it if just returned)
bool shouldCache = false;//outermostAugmentation;
for(auto use : op->users()) {
if (!isa<Instruction>(use) || returnuses.find(cast<Instruction>(use)) == returnuses.end()) {
llvm::errs() << "shouldCache for " << *op << " use " << *use << "\n";
shouldCache = true;
}
}

auto newcalled = CreateAugmentedPrimal(dyn_cast<Function>(called), AA, subconstant_args, TLI, /*differentialReturn*/subdifferentialreturn, /*return is used*/subretused).first;
auto augmentcall = BuilderZ.CreateCall(newcalled, args);
assert(augmentcall->getType()->isStructTy());
Expand All @@ -348,6 +364,7 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
gutils->erase(cast<Instruction>(tp));
tp = UndefValue::get(tpt);
}

gutils->addMalloc(BuilderZ, tp);

if (subretused) {
Expand All @@ -359,19 +376,25 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
}
assert(op->getType() == rv->getType());

gutils->addMalloc(BuilderZ, rv);
if (shouldCache) {
gutils->addMalloc(BuilderZ, rv);
}

if ((op->getType()->isPointerTy() || op->getType()->isIntegerTy()) && subdifferentialreturn) {
assert(cast<StructType>(augmentcall->getType())->getNumElements() == 3);

auto antiptr = cast<Instruction>(BuilderZ.CreateExtractValue(augmentcall, {2}, "antiptr_" + op->getName() ));
auto placeholder = cast<PHINode>(gutils->invertedPointers[op]);
if (I != E && placeholder == &*I) I++;
gutils->invertedPointers.erase(op);

assert(cast<StructType>(augmentcall->getType())->getNumElements() == 3);
auto antiptr = cast<Instruction>(BuilderZ.CreateExtractValue(augmentcall, {2}, "antiptr_" + op->getName() ));
gutils->invertedPointers[rv] = antiptr;
placeholder->replaceAllUsesWith(antiptr);

if (shouldCache) {
gutils->addMalloc(BuilderZ, antiptr);
}

gutils->erase(placeholder);
gutils->invertedPointers[rv] = antiptr;
gutils->addMalloc(BuilderZ, antiptr);
} else {
if (cast<StructType>(augmentcall->getType())->getNumElements() != 2) {
llvm::errs() << "old called: " << *called << "\n";
Expand All @@ -380,6 +403,7 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
llvm::errs() << "op subdifferentialreturn: " << subdifferentialreturn << "\n";
}
assert(cast<StructType>(augmentcall->getType())->getNumElements() == 2);

}

gutils->replaceAWithB(op,rv);
Expand Down Expand Up @@ -1197,8 +1221,9 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r

//TODO consider what to do if called == nullptr for augmentation
if (modifyPrimal && called) {
bool subdifferentialreturn = !gutils->isConstantValue(op);
auto fnandtapetype = CreateAugmentedPrimal(cast<Function>(called), AA, subconstant_args, TLI, /*differentialReturns*/subdifferentialreturn, /*return is used*/ op->getNumUses() != 0 && !op->doesNotAccessMemory());
bool subretused = op->getNumUses() != 0;
bool subdifferentialreturn = !gutils->isConstantValue(op) && subretused;
auto fnandtapetype = CreateAugmentedPrimal(cast<Function>(called), AA, subconstant_args, TLI, /*differentialReturns*/subdifferentialreturn, /*return is used*/subretused);
if (topLevel) {
Function* newcalled = fnandtapetype.first;
augmentcall = BuilderZ.CreateCall(newcalled, pre_args);
Expand Down Expand Up @@ -1233,12 +1258,12 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r
}
} else {
tape = gutils->addMalloc(BuilderZ, tape);

if (!tape->getType()->isStructTy()) {
llvm::errs() << "newFunc: " << *gutils->newFunc << "\n";
llvm::errs() << "augment: " << *fnandtapetype.first << "\n";
llvm::errs() << "op: " << *op << "\n";
llvm::errs() << "tape: " << *tape << "\n";

llvm::errs() << "tape: " << *tape << "\n";
}
assert(tape->getType()->isStructTy());

Expand Down
42 changes: 18 additions & 24 deletions enzyme/test/Enzyme/badcallused.ll
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,12 @@ attributes #1 = { noinline nounwind uwtable }

; CHECK: define internal {} @diffef(double* nocapture %x, double* %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call { { {}, i1, i1 }, i1, i1 } @augmented_subf(double* %x, double* %"x'")
; CHECK-NEXT: %1 = extractvalue { { {}, i1, i1 }, i1, i1 } %0, 0
; CHECK-NEXT: %2 = extractvalue { { {}, i1, i1 }, i1, i1 } %0, 1
; CHECK-NEXT: %sel = select i1 %2, double 2.000000e+00, double 3.000000e+00
; CHECK-NEXT: %0 = call { { {} }, i1, i1 } @augmented_subf(double* %x, double* %"x'")
; CHECK-NEXT: %1 = extractvalue { { {} }, i1, i1 } %0, 1
; CHECK-NEXT: %sel = select i1 %1, double 2.000000e+00, double 3.000000e+00
; CHECK-NEXT: store double %sel, double* %x, align 8
; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
; CHECK-NEXT: %[[dsubf:.+]] = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, i1, i1 } %1)
; CHECK-NEXT: %[[dsubf:.+]] = call {} @diffesubf(double* nonnull %x, double* %"x'", { {} } undef)
; CHECK-NEXT: ret {} undef
; CHECK-NEXT: }

Expand All @@ -66,29 +65,24 @@ attributes #1 = { noinline nounwind uwtable }
; CHECK-NEXT: ret { {}, i1, i1 } %3
; CHECK-NEXT: }

; CHECK: define internal { { {}, i1, i1 }, i1, i1 } @augmented_subf(double* nocapture %x, double* %"x'")
; CHECK: define internal { { {} }, i1, i1 } @augmented_subf(double* nocapture %x, double* %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = alloca { { {}, i1, i1 }, i1, i1 }
; CHECK-NEXT: %1 = getelementptr { { {}, i1, i1 }, i1, i1 }, { { {}, i1, i1 }, i1, i1 }* %0, i32 0, i32 0
; CHECK-NEXT: %2 = load double, double* %x, align 8
; CHECK-NEXT: %mul = fmul fast double %2, 2.000000e+00
; CHECK-NEXT: %0 = alloca { { {} }, i1, i1 }
; CHECK-NEXT: %1 = load double, double* %x, align 8
; CHECK-NEXT: %mul = fmul fast double %1, 2.000000e+00
; CHECK-NEXT: store double %mul, double* %x, align 8
; CHECK-NEXT: %3 = call { {}, i1, i1 } @augmented_metasubf(double* %x, double* %"x'")
; CHECK-NEXT: %4 = extractvalue { {}, i1, i1 } %3, 1
; CHECK-NEXT: %5 = getelementptr { {}, i1, i1 }, { {}, i1, i1 }* %1, i32 0, i32 1
; CHECK-NEXT: store i1 %4, i1* %5
; CHECK-NEXT: %antiptr_call = extractvalue { {}, i1, i1 } %3, 2
; CHECK-NEXT: %6 = getelementptr { {}, i1, i1 }, { {}, i1, i1 }* %1, i32 0, i32 2
; CHECK-NEXT: store i1 %antiptr_call, i1* %6
; CHECK-NEXT: %7 = getelementptr { { {}, i1, i1 }, i1, i1 }, { { {}, i1, i1 }, i1, i1 }* %0, i32 0, i32 1
; CHECK-NEXT: store i1 %4, i1* %7
; CHECK-NEXT: %8 = getelementptr { { {}, i1, i1 }, i1, i1 }, { { {}, i1, i1 }, i1, i1 }* %0, i32 0, i32 2
; CHECK-NEXT: store i1 %antiptr_call, i1* %8
; CHECK-NEXT: %[[toret:.+]] = load { { {}, i1, i1 }, i1, i1 }, { { {}, i1, i1 }, i1, i1 }* %0
; CHECK-NEXT: ret { { {}, i1, i1 }, i1, i1 } %[[toret]]
; CHECK-NEXT: %2 = call { {}, i1, i1 } @augmented_metasubf(double* %x, double* %"x'")
; CHECK-NEXT: %3 = extractvalue { {}, i1, i1 } %2, 1
; CHECK-NEXT: %antiptr_call = extractvalue { {}, i1, i1 } %2, 2
; CHECK-NEXT: %4 = getelementptr { { {} }, i1, i1 }, { { {} }, i1, i1 }* %0, i32 0, i32 1
; CHECK-NEXT: store i1 %3, i1* %4
; CHECK-NEXT: %5 = getelementptr { { {} }, i1, i1 }, { { {} }, i1, i1 }* %0, i32 0, i32 2
; CHECK-NEXT: store i1 %antiptr_call, i1* %5
; CHECK-NEXT: %[[toret:.+]] = load { { {} }, i1, i1 }, { { {} }, i1, i1 }* %0
; CHECK-NEXT: ret { { {} }, i1, i1 } %[[toret]]
; CHECK-NEXT: }

; CHECK: define internal {} @diffesubf(double* nocapture %x, double* %"x'", { {}, i1, i1 } %tapeArg)
; CHECK: define internal {} @diffesubf(double* nocapture %x, double* %"x'", { {} } %tapeArg)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef)
; CHECK-NEXT: %1 = load double, double* %"x'"
Expand Down
128 changes: 128 additions & 0 deletions enzyme/test/Enzyme/badcallused2.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -mem2reg -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s

; Function Attrs: noinline norecurse nounwind uwtable
define dso_local zeroext i1 @metasubf(double* nocapture %x) local_unnamed_addr #0 {
entry:
%arrayidx = getelementptr inbounds double, double* %x, i64 1
store double 3.000000e+00, double* %arrayidx, align 8
%0 = load double, double* %x, align 8
%cmp = fcmp fast oeq double %0, 2.000000e+00
ret i1 %cmp
}

define dso_local zeroext i1 @omegasubf(double* nocapture %x) local_unnamed_addr #0 {
entry:
%arrayidx = getelementptr inbounds double, double* %x, i64 1
store double 3.000000e+00, double* %arrayidx, align 8
%0 = load double, double* %x, align 8
%cmp = fcmp fast oeq double %0, 2.000000e+00
ret i1 %cmp
}

; Function Attrs: noinline norecurse nounwind uwtable
define dso_local zeroext i1 @subf(double* nocapture %x) local_unnamed_addr #0 {
entry:
%0 = load double, double* %x, align 8
%mul = fmul fast double %0, 2.000000e+00
store double %mul, double* %x, align 8
%call = tail call zeroext i1 @omegasubf(double* %x)
%call2 = tail call zeroext i1 @metasubf(double* %x)
ret i1 %call2
}

; Function Attrs: noinline norecurse nounwind uwtable
define dso_local void @f(double* nocapture %x) #0 {
entry:
%call = tail call zeroext i1 @subf(double* %x)
%sel = select i1 %call, double 2.000000e+00, double 3.000000e+00
store double %sel, double* %x, align 8
ret void
}

; Function Attrs: noinline nounwind uwtable
define dso_local double @dsumsquare(double* %x, double* %xp) local_unnamed_addr #1 {
entry:
%call = tail call fast double @__enzyme_autodiff(i8* bitcast (void (double*)* @f to i8*), double* %x, double* %xp)
ret double %call
}

declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed_addr

attributes #0 = { noinline norecurse nounwind uwtable }
attributes #1 = { noinline nounwind uwtable }

; CHECK: define internal {} @diffef(double* nocapture %x, double* %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call { { {}, {} }, i1, i1 } @augmented_subf(double* %x, double* %"x'")
; CHECK-NEXT: %1 = extractvalue { { {}, {} }, i1, i1 } %0, 1
; CHECK-NEXT: %sel = select i1 %1, double 2.000000e+00, double 3.000000e+00
; CHECK-NEXT: store double %sel, double* %x, align 8
; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
; CHECK-NEXT: %[[dsubf:.+]] = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, {} } undef)
; CHECK-NEXT: ret {} undef
; CHECK-NEXT: }

; CHECK: define internal { {}, i1, i1 } @augmented_metasubf(double* nocapture %x, double* %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = alloca { {}, i1, i1 }
; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %x, i64 1
; CHECK-NEXT: store double 3.000000e+00, double* %arrayidx, align 8
; CHECK-NEXT: %1 = load double, double* %x, align 8
; CHECK-NEXT: %cmp = fcmp fast oeq double %1, 2.000000e+00
; CHECK-NEXT: %2 = getelementptr { {}, i1, i1 }, { {}, i1, i1 }* %0, i32 0, i32 1
; CHECK-NEXT: store i1 %cmp, i1* %2
; CHECK-NEXT: %3 = load { {}, i1, i1 }, { {}, i1, i1 }* %0
; CHECK-NEXT: ret { {}, i1, i1 } %3
; CHECK-NEXT: }

; CHECK: define internal { {} } @augmented_omegasubf(double* nocapture %x, double* %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %x, i64 1
; CHECK-NEXT: store double 3.000000e+00, double* %arrayidx, align 8
; CHECK-NEXT: ret { {} } undef
; CHECK-NEXT: }

; CHECK: define internal { { {}, {} }, i1, i1 } @augmented_subf(double* nocapture %x, double* %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = alloca { { {}, {} }, i1, i1 }
; CHECK-NEXT: %1 = load double, double* %x, align 8
; CHECK-NEXT: %mul = fmul fast double %1, 2.000000e+00
; CHECK-NEXT: store double %mul, double* %x, align 8
; CHECK-NEXT: %2 = call { {} } @augmented_omegasubf(double* %x, double* %"x'")
; CHECK-NEXT: %3 = call { {}, i1, i1 } @augmented_metasubf(double* %x, double* %"x'")
; CHECK-NEXT: %4 = extractvalue { {}, i1, i1 } %3, 1
; CHECK-NEXT: %antiptr_call2 = extractvalue { {}, i1, i1 } %3, 2
; CHECK-NEXT: %5 = getelementptr { { {}, {} }, i1, i1 }, { { {}, {} }, i1, i1 }* %0, i32 0, i32 1
; CHECK-NEXT: store i1 %4, i1* %5
; CHECK-NEXT: %6 = getelementptr { { {}, {} }, i1, i1 }, { { {}, {} }, i1, i1 }* %0, i32 0, i32 2
; CHECK-NEXT: store i1 %antiptr_call2, i1* %6
; CHECK-NEXT: %[[toret:.+]] = load { { {}, {} }, i1, i1 }, { { {}, {} }, i1, i1 }* %0
; CHECK-NEXT: ret { { {}, {} }, i1, i1 } %[[toret]]
; CHECK-NEXT: }

; CHECK: define internal {} @diffesubf(double* nocapture %x, double* %"x'", { {}, {} } %tapeArg)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef)
; CHECK-NEXT: %1 = call {} @diffeomegasubf(double* %x, double* %"x'", {} undef)
; CHECK-NEXT: %2 = load double, double* %"x'"
; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
; CHECK-NEXT: %m0diffe = fmul fast double %2, 2.000000e+00
; CHECK-NEXT: %3 = load double, double* %"x'"
; CHECK-NEXT: %4 = fadd fast double %3, %m0diffe
; CHECK-NEXT: store double %4, double* %"x'"
; CHECK-NEXT: ret {} undef
; CHECK-NEXT: }

; CHECK: define internal {} @diffemetasubf(double* nocapture %x, double* %"x'", {} %tapeArg)
; CHECK-NEXT: entry:
; CHECK-NEXT: %[[tostore:.+]] = getelementptr inbounds double, double* %"x'", i64 1
; CHECK-NEXT: store double 0.000000e+00, double* %[[tostore]], align 8
; CHECK-NEXT: ret {} undef
; CHECK-NEXT: }

; CHECK: define internal {} @diffeomegasubf(double* nocapture %x, double* %"x'", {} %tapeArg)
; CHECK-NEXT: entry:
; CHECK-NEXT: %[[tostore:.+]] = getelementptr inbounds double, double* %"x'", i64 1
; CHECK-NEXT: store double 0.000000e+00, double* %[[tostore]], align 8
; CHECK-NEXT: ret {} undef
; CHECK-NEXT: }

0 comments on commit 4e5c31f

Please sign in to comment.