Skip to content

Commit

Permalink
Fix the incredibly stupid plug AssertingVH bug (#4)
Browse files Browse the repository at this point in the history
* start canonical fixups and call it a night

* Horrible horrible hacks

* Working tests

* working LLVM 6?

* don't forget to prepend fake

* Fix default behavior for unknown constant intrinsic

* cleanup prints
  • Loading branch information
wsmoses committed May 17, 2021
1 parent 3dd8c13 commit fcbe28c
Show file tree
Hide file tree
Showing 12 changed files with 4,930 additions and 197 deletions.
6 changes: 5 additions & 1 deletion enzyme/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ project(Enzyme)
SET(CMAKE_CXX_FLAGS "-Wall")
SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g")
SET(CMAKE_CXX_FLAGS_RELEASE "-O2")
SET(CMAKE_CXX_FLAGS_DEBUG "-O0 -g")

SET(CMAKE_CXX_FLAGS_DEBUG "-O0 -g -fno-omit-frame-pointer")

#SET(CMAKE_CXX_FLAGS_DEBUG "-O0 -g -fno-omit-frame-pointer -fsanitize=address")
#SET(CMAKE_LINKER_FLAGS_DEBUG "${CMAKE_LINKER_FLAGS_DEBUG} -fno-omit-frame-pointer -fsanitize=address")

set(CMAKE_CXX_STANDARD 11)
cmake_minimum_required(VERSION 3.5)
Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
if (${LLVM_VERSION_MAJOR} LESS 8)
add_llvm_loadable_module( LLVMEnzyme-${LLVM_VERSION_MAJOR}
Enzyme.cpp
SCEV/ScalarEvolutionExpander.cpp
DEPENDS
intrinsics_gen
PLUGIN_TOOL
Expand All @@ -13,6 +14,7 @@ if (${LLVM_VERSION_MAJOR} LESS 8)
else()
add_llvm_library( LLVMEnzyme-${LLVM_VERSION_MAJOR}
Enzyme.cpp
SCEV/ScalarEvolutionExpander.cpp
MODULE
DEPENDS
intrinsics_gen
Expand Down
159 changes: 49 additions & 110 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

#include <llvm/Config/llvm-config.h>

#include "SCEV/ScalarEvolutionExpander.h"

#include "llvm/Transforms/Utils/PromoteMemToReg.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Scalar/GVN.h"
Expand Down Expand Up @@ -827,7 +829,11 @@ void forceRecursiveInlining(Function *NewF, const Function* F) {
}
}

Function* preprocessForClone(Function *F, AAResults &AA) {
class GradientUtils;

PHINode* canonicalizeIVs(Type *Ty, Loop *L, ScalarEvolution &SE, DominatorTree &DT, GradientUtils* gutils);

Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI) {
static std::map<Function*,Function*> cache;
if (cache.find(F) != cache.end()) return cache[F];

Expand Down Expand Up @@ -1071,7 +1077,7 @@ Function* preprocessForClone(Function *F, AAResults &AA) {
DSEPass().run(*NewF, AM);
LoopSimplifyPass().run(*NewF, AM);

}
}

if (autodiff_print)
llvm::errs() << "after simplification :\n" << *NewF << "\n";
Expand All @@ -1084,9 +1090,9 @@ Function* preprocessForClone(Function *F, AAResults &AA) {
return NewF;
}

Function *CloneFunctionWithReturns(Function *&F, AAResults &AA, ValueToValueMapTy& ptrInputs, const std::set<unsigned>& constant_args, SmallPtrSetImpl<Value*> &constants, SmallPtrSetImpl<Value*> &nonconstant, SmallPtrSetImpl<Value*> &returnvals, ReturnType returnValue, bool differentialReturn, Twine name, ValueToValueMapTy *VMapO, bool diffeReturnArg, llvm::Type* additionalArg = nullptr) {
Function *CloneFunctionWithReturns(Function *&F, AAResults &AA, TargetLibraryInfo &TLI, ValueToValueMapTy& ptrInputs, const std::set<unsigned>& constant_args, SmallPtrSetImpl<Value*> &constants, SmallPtrSetImpl<Value*> &nonconstant, SmallPtrSetImpl<Value*> &returnvals, ReturnType returnValue, bool differentialReturn, Twine name, ValueToValueMapTy *VMapO, bool diffeReturnArg, llvm::Type* additionalArg = nullptr) {
assert(!F->empty());
F = preprocessForClone(F, AA);
F = preprocessForClone(F, AA, TLI);
diffeReturnArg &= differentialReturn;
std::vector<Type*> RetTypes;
if (returnValue == ReturnType::ArgsWithReturn)
Expand Down Expand Up @@ -1257,15 +1263,6 @@ Function *CloneFunctionWithReturns(Function *&F, AAResults &AA, ValueToValueMapT
#include "llvm/IR/Constant.h"
#include <deque>
#include "llvm/IR/CFG.h"
class GradientUtils;

PHINode* canonicalizeIVs(Type *Ty, Loop *L, ScalarEvolution &SE, DominatorTree &DT, GradientUtils *gutils);

/// \brief Replace the latch of the loop to check that IV is always less than or
/// equal to the limit.
///
/// This method assumes that the loop has a single loop latch.
Value* canonicalizeLoopLatch(PHINode *IV, Value *Limit, Loop* L, ScalarEvolution &SE, BasicBlock* ExitBlock, GradientUtils *gutils);

bool shouldRecompute(Value* val, const ValueToValueMapTy& available) {
if (available.count(val)) return false;
Expand Down Expand Up @@ -1943,7 +1940,7 @@ class GradientUtils {
SmallPtrSet<Value*,20> nonconstant;
SmallPtrSet<Value*,2> returnvals;
ValueToValueMapTy originalToNew;
auto newFunc = CloneFunctionWithReturns(todiff, AA, invertedPointers, constant_args, constants, nonconstant, returnvals, /*returnValue*/returnValue, /*differentialReturn*/differentialReturn, "fakeaugmented_"+todiff->getName(), &originalToNew, /*diffeReturnArg*/false, additionalArg);
auto newFunc = CloneFunctionWithReturns(todiff, AA, TLI, invertedPointers, constant_args, constants, nonconstant, returnvals, /*returnValue*/returnValue, /*differentialReturn*/differentialReturn, "fakeaugmented_"+todiff->getName(), &originalToNew, /*diffeReturnArg*/false, additionalArg);
auto res = new GradientUtils(newFunc, AA, TLI, invertedPointers, constants, nonconstant, returnvals, originalToNew);
res->oldFunc = todiff;
return res;
Expand Down Expand Up @@ -2230,7 +2227,6 @@ class GradientUtils {
IRBuilder <> v(putafter);
v.setFastMathFlags(getFast());
v.CreateStore(inst, scopeMap[inst]);
llvm::errs() << " place foo\n"; dumpSet(originalInstructions);
} else {

ValueToValueMapTy valmap;
Expand Down Expand Up @@ -2722,7 +2718,7 @@ class DiffeGradientUtils : public GradientUtils {
SmallPtrSet<Value*,20> nonconstant;
SmallPtrSet<Value*,2> returnvals;
ValueToValueMapTy originalToNew;
auto newFunc = CloneFunctionWithReturns(todiff, AA, invertedPointers, constant_args, constants, nonconstant, returnvals, returnValue, differentialReturn, "diffe"+todiff->getName(), &originalToNew, /*diffeReturnArg*/true, additionalArg);
auto newFunc = CloneFunctionWithReturns(todiff, AA, TLI, invertedPointers, constant_args, constants, nonconstant, returnvals, returnValue, differentialReturn, "diffe"+todiff->getName(), &originalToNew, /*diffeReturnArg*/true, additionalArg);
auto res = new DiffeGradientUtils(newFunc, AA, TLI, invertedPointers, constants, nonconstant, returnvals, originalToNew);
res->oldFunc = todiff;
return res;
Expand Down Expand Up @@ -3062,6 +3058,7 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
case Intrinsic::cos:
break;
default:
if (gutils->isConstantInstruction(inst)) continue;
assert(inst);
llvm::errs() << "cannot handle (augmented) unknown intrinsic\n" << *inst;
report_fatal_error("(augmented) unknown intrinsic");
Expand Down Expand Up @@ -3695,7 +3692,7 @@ std::pair<SmallVector<Type*,4>,SmallVector<Type*,4>> getDefaultFunctionTypeForGr
return std::pair<SmallVector<Type*,4>,SmallVector<Type*,4>>(args, outs);
}

Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& constant_args, TargetLibraryInfo &TLI, AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg) {
Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& constant_args, TargetLibraryInfo &TLI, AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg) {
static std::map<std::tuple<Function*,std::set<unsigned>, bool/*retval*/, bool/*differentialReturn*/, bool/*topLevel*/, llvm::Type*>, Function*> cachedfunctions;
auto tup = std::make_tuple(todiff, std::set<unsigned>(constant_args.begin(), constant_args.end()), returnValue, differentialReturn, topLevel, additionalArg);
if (cachedfunctions.find(tup) != cachedfunctions.end()) {
Expand Down Expand Up @@ -3783,10 +3780,10 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co

DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone(todiff, AA, TLI, constant_args, returnValue ? ReturnType::ArgsWithReturn : ReturnType::Args, differentialReturn, additionalArg);
cachedfunctions[tup] = gutils->newFunc;

gutils->forceContexts();
gutils->forceAugmentedReturns();

Argument* additionalValue = nullptr;
if (additionalArg) {
auto v = gutils->newFunc->arg_end();
Expand All @@ -3810,7 +3807,6 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co

std::map<ReturnInst*,StoreInst*> replacedReturns;


for(BasicBlock* BB: gutils->originalBlocks) {

LoopContext loopContext;
Expand Down Expand Up @@ -4141,6 +4137,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
break;
}
default:
if (gutils->isConstantInstruction(inst)) continue;
assert(inst);
llvm::errs() << "cannot handle unknown intrinsic\n" << *inst;
report_fatal_error("unknown intrinsic");
Expand Down Expand Up @@ -4239,7 +4236,6 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
if (auto dc = dyn_cast<CallInst>(val)) {
if (dc->getCalledFunction()->getName() == "malloc") {
gutils->erase(op);
llvm::errs() << " place free\n"; dumpSet(gutils->originalInstructions);
continue;
}
}
Expand Down Expand Up @@ -4903,6 +4899,8 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
}

void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, LoopInfo& LI, DominatorTree& DT) {


Value* fn = CI->getArgOperand(0);

while (auto ci = dyn_cast<CastInst>(fn)) {
Expand All @@ -4916,7 +4914,7 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo
}
auto FT = cast<Function>(fn)->getFunctionType();
assert(fn);

if (autodiff_print)
llvm::errs() << "prefn:\n" << *fn << "\n";

Expand Down Expand Up @@ -5006,6 +5004,7 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo
}

bool differentialReturn = cast<Function>(fn)->getReturnType()->isFPOrFPVectorTy();

auto newFunc = CreatePrimalAndGradient(cast<Function>(fn), constants, TLI, AA, /*should return*/false, differentialReturn, /*topLevel*/true, /*addedType*/nullptr);//, LI, DT);

if (differentialReturn)
Expand All @@ -5029,13 +5028,14 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo
}

static bool lowerAutodiffIntrinsic(Function &F, TargetLibraryInfo &TLI, AAResults &AA) {//, LoopInfo& LI, DominatorTree& DT) {

bool Changed = false;

reset:
for (BasicBlock &BB : F) {

for (auto BI = BB.rbegin(), BE = BB.rend(); BI != BE;) {
Instruction *Inst = &*BI++;
CallInst *CI = dyn_cast_or_null<CallInst>(Inst);
for (auto BI = BB.rbegin(), BE = BB.rend(); BI != BE; BI++) {
CallInst *CI = dyn_cast<CallInst>(&*BI);
if (!CI) continue;

Function *Fn = CI->getCalledFunction();
Expand All @@ -5049,103 +5049,33 @@ static bool lowerAutodiffIntrinsic(Function &F, TargetLibraryInfo &TLI, AAResult
if (Fn && ( Fn->getName() == "__enzyme_autodiff" || Fn->getName().startswith("__enzyme_autodiff")) ) {
HandleAutoDiff(CI, TLI, AA);//, LI, DT);
Changed = true;
goto reset;
}
}
}

return Changed;
}

PHINode* canonicalizeIVs(Type *Ty, Loop *L, ScalarEvolution &SE, DominatorTree &DT, GradientUtils *gutils) {
//PHINode* pn = L->getCanonicalInductionVariable();
//assert( pn && "canonical IV");
//return pn;

PHINode* canonicalizeIVs(Type *Ty, Loop *L, ScalarEvolution &SE, DominatorTree &DT, GradientUtils* gutils) {

PHINode *CanonicalIV;

/*
{
SCEVExpander e(SE, L->getHeader()->getParent()->getParent()->getDataLayout(), "ad");
assert(Ty->isIntegerTy() && "Can only insert integer induction variables!");
// Build a SCEV for {0,+,1}<L>.
// Conservatively use FlagAnyWrap for now.
const SCEV *H = SE.getAddRecExpr(SE.getConstant(Ty, 0),
SE.getConstant(Ty, 1), L, SCEV::FlagAnyWrap);
// Emit code for it.
e.setInsertPoint(&L->getHeader()->front());
Value *V = e.expand(H);
CanonicalIV = cast<PHINode>(V); //e.expandCodeFor(H, nullptr));
}
*/

BasicBlock* Header = L->getHeader();
Module* M = Header->getParent()->getParent();
const DataLayout &DL = M->getDataLayout();
SmallVector<Instruction*, 16> DeadInsts;

{
SCEVExpander Exp(SE, DL, "ad");

CanonicalIV = Exp.getOrInsertCanonicalInductionVariable(L, Ty);
fake::SCEVExpander e(SE, L->getHeader()->getParent()->getParent()->getDataLayout(), "ad");

PHINode *CanonicalIV = e.getOrInsertCanonicalInductionVariable(L, Ty);

assert (CanonicalIV && "canonicalizing IV");

assert (CanonicalIV && "canonicalizing IV");
//DEBUG(dbgs() << "Canonical induction variable " << *CanonicalIV << "\n");

SmallVector<WeakTrackingVH, 16> DeadInst0;
Exp.replaceCongruentIVs(L, &DT, DeadInst0);
e.replaceCongruentIVs(L, &DT, DeadInst0);
for (WeakTrackingVH V : DeadInst0) {
//DeadInsts.push_back(cast<Instruction>(V));
}

gutils->erase(cast<Instruction>(V)); //->eraseFromParent();
}

for (Instruction* I : DeadInsts) {
if (gutils) gutils->erase(I);
}

return CanonicalIV;

}

Value* canonicalizeLoopLatch(PHINode *IV, Value *Limit, Loop* L, ScalarEvolution &SE, BasicBlock* ExitBlock, GradientUtils *gutils) {
Value *NewCondition;
BasicBlock *Header = L->getHeader();
BasicBlock *Latch = L->getLoopLatch();
assert(Latch && "No single loop latch found for loop.");

IRBuilder<> Builder(&*Latch->getFirstInsertionPt());
Builder.setFastMathFlags(getFast());

// This process assumes that IV's increment is in Latch.

// Create comparison between IV and Limit at top of Latch.
NewCondition = Builder.CreateICmpULT(IV, Limit);

// Replace the conditional branch at the end of Latch.
BranchInst *LatchBr = dyn_cast_or_null<BranchInst>(Latch->getTerminator());
assert(LatchBr && LatchBr->isConditional() &&
"Latch does not terminate with a conditional branch.");
Builder.SetInsertPoint(Latch->getTerminator());
Builder.CreateCondBr(NewCondition, Header, ExitBlock);

// Erase the old conditional branch.
Value *OldCond = LatchBr->getCondition();
if (gutils) gutils->erase(LatchBr);

if (!OldCond->hasNUsesOrMore(1))
if (Instruction *OldCondInst = dyn_cast<Instruction>(OldCond)) {
if (gutils) gutils->erase(OldCondInst);
}


return NewCondition;
}

bool getContextM(BasicBlock *BB, LoopContext &loopContext, std::map<Loop*,LoopContext> &loopContexts, LoopInfo &LI,ScalarEvolution &SE,DominatorTree &DT, GradientUtils &gutils) {
if (auto L = LI.getLoopFor(BB)) {
if (loopContexts.find(L) != loopContexts.end()) {
Expand Down Expand Up @@ -5234,13 +5164,10 @@ bool getContextM(BasicBlock *BB, LoopContext &loopContext, std::map<Loop*,LoopCo
CanonicalSCEV, Limit) &&
"Loop backedge is not guarded by canonical comparison with limit.");

SCEVExpander Exp(SE, Preheader->getParent()->getParent()->getDataLayout(), "ad");
fake::SCEVExpander Exp(SE, Preheader->getParent()->getParent()->getDataLayout(), "ad");
LimitVar = Exp.expandCodeFor(Limit, CanonicalIV->getType(),
Preheader->getTerminator());

// Canonicalize the loop latch.
canonicalizeLoopLatch(CanonicalIV, LimitVar, L, SE, ExitBlock, &gutils);

loopContext.dynamic = false;
} else {

Expand Down Expand Up @@ -5273,7 +5200,7 @@ bool getContextM(BasicBlock *BB, LoopContext &loopContext, std::map<Loop*,LoopCo

// Remove Canonicalizable IV's
{
SCEVExpander Exp(SE, Preheader->getParent()->getParent()->getDataLayout(), "ad");
fake::SCEVExpander Exp(SE, Preheader->getParent()->getParent()->getDataLayout(), "ad");
for (BasicBlock::iterator II = Header->begin(); isa<PHINode>(II); ++II) {
PHINode *PN = cast<PHINode>(II);
if (PN == CanonicalIV) continue;
Expand Down Expand Up @@ -5336,12 +5263,24 @@ class Enzyme : public FunctionPass {
AU.addRequired<AAResultsWrapperPass>();
AU.addRequired<GlobalsAAWrapperPass>();
AU.addRequiredID(LoopSimplifyID);
AU.addRequiredID(LCSSAID);
//AU.addRequiredID(LCSSAID);

AU.addRequired<LoopInfoWrapperPass>();
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<ScalarEvolutionWrapperPass>();
}

bool runOnFunction(Function &F) override {
auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();

/*
auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
*/


return lowerAutodiffIntrinsic(F, TLI, AA);
}
};
Expand Down
Loading

0 comments on commit fcbe28c

Please sign in to comment.