Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix calling convention and other small bugs #16

Merged
merged 12 commits into from
Oct 15, 2019
Prev Previous commit
Next Next commit
Fix various calling convention issues
  • Loading branch information
wsmoses committed Oct 12, 2019
commit 374122805137c1660adeaa289bb4f5fd29448a15
317 changes: 200 additions & 117 deletions enzyme/Enzyme/EnzymeLogic.cpp

Large diffs are not rendered by default.

14 changes: 8 additions & 6 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI)
AM.registerPass([] { return MemorySSAAnalysis(); });
AM.registerPass([] { return DominatorTreeAnalysis(); });
AM.registerPass([] { return MemoryDependenceAnalysis(); });
AM.registerPass([] { return LoopAnalysis(); });
AM.registerPass([] { return OptimizationRemarkEmitterAnalysis(); });
#if LLVM_VERSION_MAJOR > 6
AM.registerPass([] { return PhiValuesAnalysis(); });
#endif
Expand Down Expand Up @@ -520,14 +522,14 @@ Function *CloneFunctionWithReturns(Function *&F, AAResults &AA, TargetLibraryInf
ArgTypes.push_back(additionalArg);
}
Type* RetType = StructType::get(F->getContext(), RetTypes);
if (returnValue == ReturnType::TapeAndReturns) {
if (returnValue == ReturnType::TapeAndReturns || returnValue == ReturnType::Tape) {
RetTypes.clear();
RetTypes.push_back(Type::getInt8PtrTy(F->getContext()));
if (!F->getReturnType()->isVoidTy()) {
RetTypes.push_back(F->getReturnType());
if (F->getReturnType()->isPointerTy() || F->getReturnType()->isIntegerTy())
RetTypes.push_back(F->getReturnType());
}
if (!F->getReturnType()->isVoidTy() && returnValue == ReturnType::TapeAndReturns) {
RetTypes.push_back(F->getReturnType());
if ( (F->getReturnType()->isPointerTy() || F->getReturnType()->isIntegerTy()) && differentialReturn)
RetTypes.push_back(F->getReturnType());
}
RetType = StructType::get(F->getContext(), RetTypes);
}

Expand Down
12 changes: 11 additions & 1 deletion enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) {
return val;
} else if (auto cint = dyn_cast<ConstantInt>(val)) {
if (cint->isZero()) return cint;
//this is extra
if (cint->isOne()) return cint;
}

Expand Down Expand Up @@ -370,6 +369,15 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) {
IRBuilder <> bb(arg);
auto li = bb.CreateLoad(invertPointerM(arg->getOperand(0), bb), arg->getName()+"'ipl");
li->setAlignment(arg->getAlignment());
li->setVolatile(arg->isVolatile());
li->setOrdering(arg->getOrdering());
li->setSyncScopeID(arg->getSyncScopeID ());
invertedPointers[arg] = li;
return lookupM(invertedPointers[arg], BuilderM);
} else if (auto arg = dyn_cast<BinaryOperator>(val)) {
assert(arg->getType()->isIntOrIntVectorTy());
IRBuilder <> bb(arg);
auto li = bb.CreateBinOp(arg->getOpcode(), invertPointerM(arg->getOperand(0), bb), invertPointerM(arg->getOperand(1), bb), arg->getName());
invertedPointers[arg] = li;
return lookupM(invertedPointers[arg], BuilderM);
} else if (auto arg = dyn_cast<GetElementPtrInst>(val)) {
Expand All @@ -381,6 +389,8 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) {
invertargs.push_back(b);
}
auto result = bb.CreateGEP(invertPointerM(arg->getPointerOperand(), bb), invertargs, arg->getName()+"'ipge");
if (auto gep = dyn_cast<GetElementPtrInst>(result))
gep->setIsInBounds(arg->isInBounds());
invertedPointers[arg] = result;
return lookupM(invertedPointers[arg], BuilderM);
}
Expand Down
40 changes: 26 additions & 14 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ class GradientUtils {
replaceAWithB(placeholder, anti);
erase(placeholder);

anti = addMalloc<Instruction>(bb, anti);
anti = cast<Instruction>(addMalloc(bb, anti));
invertedPointers[call] = anti;

if (tape == nullptr) {
Expand All @@ -288,26 +288,34 @@ class GradientUtils {
return anti;
}

template<typename T>
T* addMalloc(IRBuilder<> &BuilderQ, T* malloc) {
Value* addMalloc(IRBuilder<> &BuilderQ, Value* malloc) {
if (tape) {
if (!tape->getType()->isStructTy()) {
llvm::errs() << "addMalloc incorrect tape type: " << *tape << "\n";
}
assert(tape->getType()->isStructTy());
if (tapeidx >= cast<StructType>(tape->getType())->getNumElements()) {
llvm::errs() << "oldFunc: " <<*oldFunc << "\n";
llvm::errs() << "newFunc: " <<*newFunc << "\n";
if (malloc)
llvm::errs() << "malloc: " <<*malloc << "\n";
llvm::errs() << "tape: " <<*tape << "\n";
llvm::errs() << "tapeidx: " << tapeidx << "\n";
}
assert(tapeidx < cast<StructType>(tape->getType())->getNumElements());
Instruction* ret = cast<Instruction>(BuilderQ.CreateExtractValue(tape, {tapeidx}));
Instruction* origret = ret;
tapeidx++;

if (ret->getType()->isEmptyTy()) {
/*
if (auto inst = dyn_cast<Instruction>(malloc)) {

if (auto inst = dyn_cast_or_null<Instruction>(malloc)) {
inst->replaceAllUsesWith(UndefValue::get(ret->getType()));
erase(inst);
}
*/
return ret;
//UndefValue::get(ret->getType());

//return ret;
return UndefValue::get(ret->getType());
}

BasicBlock* parent = BuilderQ.GetInsertBlock();
Expand Down Expand Up @@ -503,14 +511,11 @@ class GradientUtils {
erase(cast<Instruction>(malloc));
ret->setName(n);
}
llvm::errs() << " retrieved from malloc " << *ret << "\n";
return ret;
} else {
assert(malloc);
assert(!isa<PHINode>(malloc));

llvm::errs() << " adding to malloc " << *malloc << "\n";

if (isa<UndefValue>(malloc)) {
addedMallocs.push_back(malloc);
return malloc;
Expand Down Expand Up @@ -1230,6 +1235,7 @@ class DiffeGradientUtils : public GradientUtils {
differentials[val] = entryBuilder.CreateAlloca(val->getType(), nullptr, val->getName()+"'de");
entryBuilder.CreateStore(Constant::getNullValue(val->getType()), differentials[val]);
}
assert(cast<PointerType>(differentials[val]->getType())->getElementType() == val->getType());
return differentials[val];
}

Expand Down Expand Up @@ -1320,7 +1326,13 @@ class DiffeGradientUtils : public GradientUtils {
llvm::errs() << *val << "\n";
}
assert(!isConstantValue(val));
BuilderM.CreateStore(toset, getDifferential(val));
Value* tostore = getDifferential(val);
if (toset->getType() != cast<PointerType>(tostore->getType())->getElementType()) {
llvm::errs() << "toset:" << *toset << "\n";
llvm::errs() << "tostore:" << *tostore << "\n";
}
assert(toset->getType() == cast<PointerType>(tostore->getType())->getElementType());
BuilderM.CreateStore(toset, tostore);
}

SelectInst* addToDiffeIndexed(Value* val, Value* dif, ArrayRef<Value*> idxs, IRBuilder<> &BuilderM) {
Expand Down Expand Up @@ -1378,9 +1390,9 @@ class DiffeGradientUtils : public GradientUtils {
BuilderM.CreateStore(res, ptr);
}

void setPtrDiffe(Value* ptr, Value* newval, IRBuilder<> &BuilderM) {
StoreInst* setPtrDiffe(Value* ptr, Value* newval, IRBuilder<> &BuilderM) {
ptr = invertPointerM(ptr, BuilderM);
BuilderM.CreateStore(newval, ptr);
return BuilderM.CreateStore(newval, ptr);
}

};
Expand Down
5 changes: 4 additions & 1 deletion enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ static inline bool hasMetadata(const llvm::GlobalObject* O, llvm::StringRef kind
}

enum class ReturnType {
ArgsWithReturn, Args, TapeAndReturns
ArgsWithReturn,
Args,
TapeAndReturns,
Tape,
};

enum class DIFFE_TYPE {
Expand Down
85 changes: 85 additions & 0 deletions enzyme/test/Enzyme/badcall.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -mem2reg -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s

; Function Attrs: noinline norecurse nounwind uwtable
define dso_local zeroext i1 @metasubf(double* nocapture %x) local_unnamed_addr #0 {
entry:
%arrayidx = getelementptr inbounds double, double* %x, i64 1
store double 3.000000e+00, double* %arrayidx, align 8
%0 = load double, double* %x, align 8
%cmp = fcmp fast oeq double %0, 2.000000e+00
ret i1 %cmp
}

; Function Attrs: noinline norecurse nounwind uwtable
define dso_local zeroext i1 @subf(double* nocapture %x) local_unnamed_addr #0 {
entry:
%0 = load double, double* %x, align 8
%mul = fmul fast double %0, 2.000000e+00
store double %mul, double* %x, align 8
%call = tail call zeroext i1 @metasubf(double* %x)
ret i1 %call
}

; Function Attrs: noinline norecurse nounwind uwtable
define dso_local void @f(double* nocapture %x) #0 {
entry:
%call = tail call zeroext i1 @subf(double* %x)
store double 2.000000e+00, double* %x, align 8
ret void
}

; Function Attrs: noinline nounwind uwtable
define dso_local double @dsumsquare(double* %x, double* %xp) local_unnamed_addr #1 {
entry:
%call = tail call fast double @__enzyme_autodiff(i8* bitcast (void (double*)* @f to i8*), double* %x, double* %xp)
ret double %call
}

declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed_addr

attributes #0 = { noinline norecurse nounwind uwtable }
attributes #1 = { noinline nounwind uwtable }

; CHECK: define internal {} @diffef(double* nocapture %x, double* %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call { { {} } } @augmented_subf(double* %x, double* %"x'")
; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8
; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
; CHECK-NEXT: %1 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {} } undef)
; CHECK-NEXT: ret {} undef
; CHECK-NEXT: }

; CHECK: define internal { {} } @augmented_metasubf(double* nocapture %x, double* %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %x, i64 1
; CHECK-NEXT: store double 3.000000e+00, double* %arrayidx, align 8
; CHECK-NEXT: ret { {} } undef
; CHECK-NEXT: }

; CHECK: define internal { { {} } } @augmented_subf(double* nocapture %x, double* %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = load double, double* %x, align 8
; CHECK-NEXT: %mul = fmul fast double %0, 2.000000e+00
; CHECK-NEXT: store double %mul, double* %x, align 8
; CHECK-NEXT: %1 = call { {} } @augmented_metasubf(double* %x, double* %"x'")
; CHECK-NEXT: ret { { {} } } undef
; CHECK-NEXT: }

; CHECK: define internal {} @diffesubf(double* nocapture %x, double* %"x'", { {} } %tapeArg)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef)
; CHECK-NEXT: %1 = load double, double* %"x'"
; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
; CHECK-NEXT: %m0diffe = fmul fast double %1, 2.000000e+00
; CHECK-NEXT: %2 = load double, double* %"x'"
; CHECK-NEXT: %3 = fadd fast double %2, %m0diffe
; CHECK-NEXT: store double %3, double* %"x'"
; CHECK-NEXT: ret {} undef
; CHECK-NEXT: }

; CHECK: define internal {} @diffemetasubf(double* nocapture %x, double* %"x'", {} %tapeArg)
; CHECK-NEXT: entry:
; CHECK-NEXT: %[[tostore:.+]] = getelementptr inbounds double, double* %"x'", i64 1
; CHECK-NEXT: store double 0.000000e+00, double* %[[tostore]], align 8
; CHECK-NEXT: ret {} undef
; CHECK-NEXT: }
109 changes: 109 additions & 0 deletions enzyme/test/Enzyme/badcall2.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -mem2reg -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s

; Function Attrs: noinline norecurse nounwind uwtable
define dso_local zeroext i1 @metasubf(double* nocapture %x) local_unnamed_addr #0 {
entry:
%arrayidx = getelementptr inbounds double, double* %x, i64 1
store double 3.000000e+00, double* %arrayidx, align 8
%0 = load double, double* %x, align 8
%cmp = fcmp fast oeq double %0, 2.000000e+00
ret i1 %cmp
}

; Function Attrs: noinline norecurse nounwind uwtable
define dso_local zeroext i1 @othermetasubf(double* nocapture %x) local_unnamed_addr #0 {
entry:
%arrayidx = getelementptr inbounds double, double* %x, i64 1
store double 4.000000e+00, double* %arrayidx, align 8
%0 = load double, double* %x, align 8
%cmp = fcmp fast oeq double %0, 3.000000e+00
ret i1 %cmp
}

; Function Attrs: noinline norecurse nounwind uwtable
define dso_local zeroext i1 @subf(double* nocapture %x) local_unnamed_addr #0 {
entry:
%0 = load double, double* %x, align 8
%mul = fmul fast double %0, 2.000000e+00
store double %mul, double* %x, align 8
%call = tail call zeroext i1 @metasubf(double* %x)
%call1 = tail call zeroext i1 @othermetasubf(double* %x)
ret i1 %call1
}

; Function Attrs: noinline norecurse nounwind uwtable
define dso_local void @f(double* nocapture %x) #0 {
entry:
%call = tail call zeroext i1 @subf(double* %x)
store double 2.000000e+00, double* %x, align 8
ret void
}

; Function Attrs: noinline nounwind uwtable
define dso_local double @dsumsquare(double* %x, double* %xp) local_unnamed_addr #1 {
entry:
%call = tail call fast double @__enzyme_autodiff(i8* bitcast (void (double*)* @f to i8*), double* %x, double* %xp)
ret double %call
}

declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed_addr

; CHECK: define internal {} @diffef(double* nocapture %x, double* %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call { { {}, {} } } @augmented_subf(double* %x, double* %"x'")
; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8
; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8
; CHECK-NEXT: %1 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, {} } undef)
; CHECK-NEXT: ret {} undef
; CHECK-NEXT: }

; CHECK: define internal { {} } @augmented_othermetasubf(double* nocapture %x, double* %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %x, i64 1
; CHECK-NEXT: store double 4.000000e+00, double* %arrayidx, align 8
; CHECK-NEXT: ret { {} } undef
; CHECK-NEXT: }

; CHECK: define internal { {} } @augmented_metasubf(double* nocapture %x, double* %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %x, i64 1
; CHECK-NEXT: store double 3.000000e+00, double* %arrayidx, align 8
; CHECK-NEXT: ret { {} } undef
; CHECK-NEXT: }

; CHECK: define internal { { {}, {} } } @augmented_subf(double* nocapture %x, double* %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = load double, double* %x, align 8
; CHECK-NEXT: %mul = fmul fast double %0, 2.000000e+00
; CHECK-NEXT: store double %mul, double* %x, align 8
; CHECK-NEXT: %1 = call { {} } @augmented_metasubf(double* %x, double* %"x'")
; CHECK-NEXT: %2 = call { {} } @augmented_othermetasubf(double* %x, double* %"x'")
; CHECK-NEXT: ret { { {}, {} } } undef
; CHECK-NEXT: }

; CHECK: define internal {} @diffesubf(double* nocapture %x, double* %"x'", { {}, {} } %tapeArg)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = call {} @diffeothermetasubf(double* %x, double* %"x'", {} undef)
; CHECK-NEXT: %1 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef)
; CHECK-NEXT: %2 = load double, double* %"x'"
; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
; CHECK-NEXT: %m0diffe = fmul fast double %2, 2.000000e+00
; CHECK-NEXT: %3 = load double, double* %"x'"
; CHECK-NEXT: %4 = fadd fast double %3, %m0diffe
; CHECK-NEXT: store double %4, double* %"x'"
; CHECK-NEXT: ret {} undef
; CHECK-NEXT: }

; CHECK: define internal {} @diffeothermetasubf(double* nocapture %x, double* %"x'", {} %tapeArg)
; CHECK-NEXT: entry:
; CHECK-NEXT: %[[tostore:.+]] = getelementptr inbounds double, double* %"x'", i64 1
; CHECK-NEXT: store double 0.000000e+00, double* %[[tostore]], align 8
; CHECK-NEXT: ret {} undef
; CHECK-NEXT: }

; CHECK: define internal {} @diffemetasubf(double* nocapture %x, double* %"x'", {} %tapeArg)
; CHECK-NEXT: entry:
; CHECK-NEXT: %[[tostore2:.+]] = getelementptr inbounds double, double* %"x'", i64 1
; CHECK-NEXT: store double 0.000000e+00, double* %[[tostore2]], align 8
; CHECK-NEXT: ret {} undef
; CHECK-NEXT: }
Loading