
add more caching layers to speed compilation
wsmoses committed Oct 12, 2019
1 parent a8bdf3b commit 7743e46
Showing 6 changed files with 145 additions and 71 deletions.
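
The changes share one idiom: memoize IR regenerated during the reverse pass on a (value, insertion block) key, so a value looked up a second time from the same block reuses the instructions built the first time instead of recomputing them. A minimal sketch of that idiom, assuming the usual LLVM headers; the helper name cachedLookup and the compute callback are illustrative and do not appear in the diff:

  #include <map>
  #include <utility>
  #include "llvm/IR/BasicBlock.h"
  #include "llvm/IR/IRBuilder.h"
  #include "llvm/IR/Value.h"

  // Results are keyed on the value being looked up plus the block the builder
  // is currently inserting into, mirroring the caches added to lookupM/unwrapM.
  static std::map<std::pair<llvm::Value*, llvm::BasicBlock*>, llvm::Value*> cache;

  template <typename ComputeFn>
  llvm::Value* cachedLookup(llvm::Value* val, llvm::IRBuilder<>& B, ComputeFn compute) {
    auto key = std::make_pair(val, B.GetInsertBlock());
    auto found = cache.find(key);
    if (found != cache.end())
      return found->second;          // cache hit: reuse previously generated IR
    llvm::Value* result = compute(val, B);
    return cache[key] = result;      // memoize for later requests in this block
  }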
7 changes: 1 addition & 6 deletions enzyme/Enzyme/EnzymeLogic.cpp
@@ -1593,13 +1593,8 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
//}
Builder2.setFastMathFlags(getFast());

std::map<Value*,Value*> alreadyLoaded;

std::function<Value*(Value*)> lookup = [&](Value* val) -> Value* {
if (alreadyLoaded.find(val) != alreadyLoaded.end()) {
return alreadyLoaded[val];
}
return alreadyLoaded[val] = gutils->lookupM(val, Builder2);
return gutils->lookupM(val, Builder2);
};

auto diffe = [&Builder2,&gutils](Value* val) -> Value* {
9 changes: 9 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
@@ -635,6 +635,7 @@ void removeRedundantIVs(const Loop* L, BasicBlock* Header, BasicBlock* Preheader
assert(cmp->getOperand(0) == increment);

auto scv = SE.getSCEVAtScope(cmp->getOperand(1), L);
llvm::errs() << "going to think about " << *cmp << "\n";
if (cmp->isUnsigned() || (scv != SE.getCouldNotCompute() && SE.isKnownNonNegative(scv)) ) {

// valid replacements (since unsigned comparison and i starts at 0 counting up)
@@ -795,6 +796,13 @@ Value* GradientUtils::lookupM(Value* val, IRBuilder<>& BuilderM) {
val = inst = fixLCSSA(inst, BuilderM);

assert(!this->isOriginalBlock(*BuilderM.GetInsertBlock()));

static std::map<std::pair<Value*, BasicBlock*>, Value*> cache;
auto idx = std::make_pair(val, BuilderM.GetInsertBlock());
if (cache.find(idx) != cache.end()) {
return cache[idx];
}

LoopContext lc;
bool inLoop = getContext(inst->getParent(), lc);

Expand Down Expand Up @@ -826,6 +834,7 @@ Value* GradientUtils::lookupM(Value* val, IRBuilder<>& BuilderM) {
assert(scopeMap[inst]);
Value* result = lookupValueFromCache(BuilderM, inst->getParent(), scopeMap[inst]);
assert(result->getType() == inst->getType());
cache[idx] = result;
return result;
}

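Because lookupM now memoizes internally on (value, insertion block), callers no longer need their own per-invocation maps; that is why the alreadyLoaded map is dropped from EnzymeLogic.cpp above. A hedged caller-side sketch, reusing gutils and Builder2 from the diff; someOperand is a hypothetical stand-in:

  auto lookup = [&](llvm::Value* val) -> llvm::Value* {
    return gutils->lookupM(val, Builder2);  // first request builds the lookup IR
  };
  llvm::Value* a = lookup(someOperand);     // generates (or reuses) the lookup code
  llvm::Value* b = lookup(someOperand);     // later request: served from lookupM's internal cache
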
167 changes: 120 additions & 47 deletions enzyme/Enzyme/GradientUtils.h
@@ -717,63 +717,131 @@ class GradientUtils {
}

Value* unwrapM(Value* val, IRBuilder<>& BuilderM, const ValueToValueMapTy& available, bool lookupIfAble) {
assert(val);
assert(val);

static std::map<std::pair<Value*, BasicBlock*>, Value*> cache;
auto cidx = std::make_pair(val, BuilderM.GetInsertBlock());
if (cache.find(cidx) != cache.end()) {
return cache[cidx];
}

if (available.count(val)) {
return available.lookup(val);
}

if (auto inst = dyn_cast<Instruction>(val)) {
if (isOriginalBlock(*BuilderM.GetInsertBlock())) {
if (BuilderM.GetInsertBlock()->size() && BuilderM.GetInsertPoint() != BuilderM.GetInsertBlock()->end()) {
if (DT.dominates(inst, &*BuilderM.GetInsertPoint())) {
//llvm::errs() << "allowed " << *inst << "from domination\n";
return inst;
}
} else {
if (DT.dominates(inst, BuilderM.GetInsertBlock())) {
//llvm::errs() << "allowed " << *inst << "from block domination\n";
return inst;
}
}
}
}

if (isa<Argument>(val) || isa<Constant>(val)) {
cache[std::make_pair(val, BuilderM.GetInsertBlock())] = val;
return val;
} else if (isa<AllocaInst>(val)) {
cache[std::make_pair(val, BuilderM.GetInsertBlock())] = val;
return val;
} else if (auto op = dyn_cast<CastInst>(val)) {
auto op0 = unwrapM(op->getOperand(0), BuilderM, available, lookupIfAble);
if (op0 == nullptr) goto endCheck;
return BuilderM.CreateCast(op->getOpcode(), op0, op->getDestTy(), op->getName()+"_unwrap");
auto toreturn = BuilderM.CreateCast(op->getOpcode(), op0, op->getDestTy(), op->getName()+"_unwrap");
if (cache.find(std::make_pair((Value*)op->getOperand(0), BuilderM.GetInsertBlock())) != cache.end()) {
cache[cidx] = toreturn;
}
return toreturn;
} else if (auto op = dyn_cast<ExtractValueInst>(val)) {
auto op0 = unwrapM(op->getAggregateOperand(), BuilderM, available, lookupIfAble);
if (op0 == nullptr) goto endCheck;
return BuilderM.CreateExtractValue(op0, op->getIndices(), op->getName()+"_unwrap");
auto toreturn = BuilderM.CreateExtractValue(op0, op->getIndices(), op->getName()+"_unwrap");
if (cache.find(std::make_pair((Value*)op->getOperand(0), BuilderM.GetInsertBlock())) != cache.end()) {
cache[cidx] = toreturn;
}
return toreturn;
} else if (auto op = dyn_cast<BinaryOperator>(val)) {
auto op0 = unwrapM(op->getOperand(0), BuilderM, available, lookupIfAble);
if (op0 == nullptr) goto endCheck;
auto op1 = unwrapM(op->getOperand(1), BuilderM, available, lookupIfAble);
if (op1 == nullptr) goto endCheck;
return BuilderM.CreateBinOp(op->getOpcode(), op0, op1);
auto toreturn = BuilderM.CreateBinOp(op->getOpcode(), op0, op1);
if (
(cache.find(std::make_pair((Value*)op->getOperand(0), BuilderM.GetInsertBlock())) != cache.end()) &&
(cache.find(std::make_pair((Value*)op->getOperand(1), BuilderM.GetInsertBlock())) != cache.end()) ) {
cache[cidx] = toreturn;
}
return toreturn;
} else if (auto op = dyn_cast<ICmpInst>(val)) {
auto op0 = unwrapM(op->getOperand(0), BuilderM, available, lookupIfAble);
if (op0 == nullptr) goto endCheck;
auto op1 = unwrapM(op->getOperand(1), BuilderM, available, lookupIfAble);
if (op1 == nullptr) goto endCheck;
return BuilderM.CreateICmp(op->getPredicate(), op0, op1);
auto toreturn = BuilderM.CreateICmp(op->getPredicate(), op0, op1);
if (
(cache.find(std::make_pair((Value*)op->getOperand(0), BuilderM.GetInsertBlock())) != cache.end()) &&
(cache.find(std::make_pair((Value*)op->getOperand(1), BuilderM.GetInsertBlock())) != cache.end()) ) {
cache[cidx] = toreturn;
}
return toreturn;
} else if (auto op = dyn_cast<FCmpInst>(val)) {
auto op0 = unwrapM(op->getOperand(0), BuilderM, available, lookupIfAble);
if (op0 == nullptr) goto endCheck;
auto op1 = unwrapM(op->getOperand(1), BuilderM, available, lookupIfAble);
if (op1 == nullptr) goto endCheck;
return BuilderM.CreateFCmp(op->getPredicate(), op0, op1);
auto toreturn = BuilderM.CreateFCmp(op->getPredicate(), op0, op1);
if (
(cache.find(std::make_pair((Value*)op->getOperand(0), BuilderM.GetInsertBlock())) != cache.end()) &&
(cache.find(std::make_pair((Value*)op->getOperand(1), BuilderM.GetInsertBlock())) != cache.end()) ) {
cache[cidx] = toreturn;
}
return toreturn;
} else if (auto op = dyn_cast<SelectInst>(val)) {
auto op0 = unwrapM(op->getOperand(0), BuilderM, available, lookupIfAble);
if (op0 == nullptr) goto endCheck;
auto op1 = unwrapM(op->getOperand(1), BuilderM, available, lookupIfAble);
if (op1 == nullptr) goto endCheck;
auto op2 = unwrapM(op->getOperand(2), BuilderM, available, lookupIfAble);
if (op2 == nullptr) goto endCheck;
return BuilderM.CreateSelect(op0, op1, op2);
auto toreturn = BuilderM.CreateSelect(op0, op1, op2);
if (
(cache.find(std::make_pair((Value*)op->getOperand(0), BuilderM.GetInsertBlock())) != cache.end()) &&
(cache.find(std::make_pair((Value*)op->getOperand(1), BuilderM.GetInsertBlock())) != cache.end()) &&
(cache.find(std::make_pair((Value*)op->getOperand(2), BuilderM.GetInsertBlock())) != cache.end()) ) {
cache[cidx] = toreturn;
}
return toreturn;
} else if (auto inst = dyn_cast<GetElementPtrInst>(val)) {
auto ptr = unwrapM(inst->getPointerOperand(), BuilderM, available, lookupIfAble);
if (ptr == nullptr) goto endCheck;
bool cached = cache.find(std::make_pair(inst->getPointerOperand(), BuilderM.GetInsertBlock())) != cache.end();
SmallVector<Value*,4> ind;
for(auto& a : inst->indices()) {
auto op = unwrapM(a, BuilderM,available, lookupIfAble);
if (op == nullptr) goto endCheck;
cached &= cache.find(std::make_pair((Value*)a, BuilderM.GetInsertBlock())) != cache.end();
ind.push_back(op);
}
return BuilderM.CreateGEP(ptr, ind);
auto toreturn = BuilderM.CreateGEP(ptr, ind, inst->getName() + "_unwrap");
if (cached) {
cache[cidx] = toreturn;
}
return toreturn;
} else if (auto load = dyn_cast<LoadInst>(val)) {
Value* idx = unwrapM(load->getOperand(0), BuilderM, available, lookupIfAble);
if (idx == nullptr) goto endCheck;
return BuilderM.CreateLoad(idx);
auto toreturn = BuilderM.CreateLoad(idx);
if (cache.find(std::make_pair((Value*)load->getOperand(0), BuilderM.GetInsertBlock())) != cache.end()) {
cache[cidx] = toreturn;
}
return toreturn;
} else if (auto op = dyn_cast<IntrinsicInst>(val)) {
switch(op->getIntrinsicID()) {
case Intrinsic::sin: {
@@ -839,7 +907,6 @@ class GradientUtils {
if (!inLoop) {
return entryBuilder.CreateAlloca(T, nullptr, name+"_cache");
} else {
Value* size = nullptr;

BasicBlock* outermostPreheader = nullptr;

@@ -853,38 +920,45 @@

IRBuilder <> allocationBuilder(&outermostPreheader->back());

for(LoopContext idx = lc; ; getContext(idx.parent->getHeader(), idx) ) {
//TODO handle allocations for dynamic loops
if (idx.dynamic && idx.parent != nullptr) {
assert(idx.var);
assert(idx.var->getParent());
assert(idx.var->getParent()->getParent());
llvm::errs() << *idx.var->getParent()->getParent() << "\n"
<< "idx.var=" <<*idx.var << "\n"
<< "idx.limit=" <<*idx.limit << "\n";
llvm::errs() << "cannot handle non-outermost dynamic loop\n";
assert(0 && "cannot handle non-outermost dynamic loop");
}
Value* ns = nullptr;
Type* intT = idx.dynamic ? cast<PointerType>(idx.limit->getType())->getElementType() : idx.limit->getType();
if (idx.dynamic) {
ns = ConstantInt::get(intT, 1);
} else {
Value* limitm1 = nullptr;
ValueToValueMapTy emptyMap;
limitm1 = unwrapM(idx.limit, allocationBuilder, emptyMap, /*lookupIfAble*/false);
if (limitm1 == nullptr) {
assert(outermostPreheader);
assert(outermostPreheader->getParent());
llvm::errs() << *outermostPreheader->getParent() << "\n";
llvm::errs() << "needed value " << *idx.limit << " at " << allocationBuilder.GetInsertBlock()->getName() << "\n";
Value* size = nullptr;
static std::map<BasicBlock*, Value*> sizecache;
if (sizecache.find(lc.header) != sizecache.end()) {
size = sizecache[lc.header];
} else {
for(LoopContext idx = lc; ; getContext(idx.parent->getHeader(), idx) ) {
//TODO handle allocations for dynamic loops
if (idx.dynamic && idx.parent != nullptr) {
assert(idx.var);
assert(idx.var->getParent());
assert(idx.var->getParent()->getParent());
llvm::errs() << *idx.var->getParent()->getParent() << "\n"
<< "idx.var=" <<*idx.var << "\n"
<< "idx.limit=" <<*idx.limit << "\n";
llvm::errs() << "cannot handle non-outermost dynamic loop\n";
assert(0 && "cannot handle non-outermost dynamic loop");
}
Value* ns = nullptr;
Type* intT = idx.dynamic ? cast<PointerType>(idx.limit->getType())->getElementType() : idx.limit->getType();
if (idx.dynamic) {
ns = ConstantInt::get(intT, 1);
} else {
Value* limitm1 = nullptr;
ValueToValueMapTy emptyMap;
limitm1 = unwrapM(idx.limit, allocationBuilder, emptyMap, /*lookupIfAble*/false);
if (limitm1 == nullptr) {
assert(outermostPreheader);
assert(outermostPreheader->getParent());
llvm::errs() << *outermostPreheader->getParent() << "\n";
llvm::errs() << "needed value " << *idx.limit << " at " << allocationBuilder.GetInsertBlock()->getName() << "\n";
}
assert(limitm1);
ns = allocationBuilder.CreateNUWAdd(limitm1, ConstantInt::get(intT, 1));
}
if (size == nullptr) size = ns;
else size = allocationBuilder.CreateNUWMul(size, ns);
if (idx.parent == nullptr) break;
}
assert(limitm1);
ns = allocationBuilder.CreateNUWAdd(limitm1, ConstantInt::get(intT, 1));
}
if (size == nullptr) size = ns;
else size = allocationBuilder.CreateNUWMul(size, ns);
if (idx.parent == nullptr) break;
sizecache[lc.header] = size;
}

auto firstallocation = CallInst::CreateMalloc(
@@ -955,6 +1029,7 @@ class GradientUtils {
limits.push_back(lim);
}

/*
Value* idx = nullptr;
for(unsigned i=0; i<indices.size(); i++) {
if (i == 0) {
Expand All @@ -963,20 +1038,18 @@ class GradientUtils {
auto mul = v.CreateNUWMul(indices[i], limits[i-1]);
idx = v.CreateNUWAdd(idx, mul);
}
}
}*/

if (dynamicPHI != nullptr) {
Type *BPTy = Type::getInt8PtrTy(v.GetInsertBlock()->getContext());
auto realloc = newFunc->getParent()->getOrInsertFunction("realloc", BPTy, BPTy, size->getType());
Value* allocation = v.CreateLoad(holderAlloc);
auto foo = v.CreateNUWAdd(dynamicPHI, ConstantInt::get(dynamicPHI->getType(), 1));
Value* foo = v.CreateNUWAdd(dynamicPHI, ConstantInt::get(dynamicPHI->getType(), 1));
Value* realloc_size = v.CreateNUWMul(size, foo);
Value* idxs[2] = {
v.CreatePointerCast(allocation, BPTy),
v.CreateNUWMul(
ConstantInt::get(size->getType(), newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(T)/8),
v.CreateNUWMul(
size, foo
)
ConstantInt::get(size->getType(), newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(T)/8), realloc_size
)
};

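Two further layers sit in GradientUtils.h: unwrapM memoizes rebuilt expressions on the same (value, insertion block) key, but only records a result when every operand it recursed into was itself a cache hit, so an expression whose operands had to be re-derived is not blindly reused; and the per-loop allocation size is now cached by loop header in sizecache. A condensed restatement of the BinaryOperator case, with the null-operand checks omitted for brevity (a sketch, not the verbatim Enzyme code):

  auto op0 = unwrapM(op->getOperand(0), BuilderM, available, lookupIfAble);
  auto op1 = unwrapM(op->getOperand(1), BuilderM, available, lookupIfAble);
  auto toreturn = BuilderM.CreateBinOp(op->getOpcode(), op0, op1);
  llvm::BasicBlock* blk = BuilderM.GetInsertBlock();
  // Memoize only when both operands were already cached for this block.
  if (cache.count(std::make_pair((llvm::Value*)op->getOperand(0), blk)) &&
      cache.count(std::make_pair((llvm::Value*)op->getOperand(1), blk)))
    cache[std::make_pair(val, blk)] = toreturn;
  return toreturn;
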
11 changes: 5 additions & 6 deletions enzyme/test/Enzyme/cppllist.ll
@@ -206,15 +206,14 @@ attributes #8 = { builtin nounwind }
; CHECK-NEXT: %[[antivar:.+]] = phi i64 [ %n, %[[invertdelete]] ], [ %[[isub:.+]], %invertfor.body.i ]
; CHECK-NEXT: %[[isub]] = add i64 %[[antivar]], -1
; CHECK-NEXT: %[[gepiv:.+]] = getelementptr i8*, i8** %"call'mi_malloccache.i", i64 %[[antivar]]
; CHECK-NEXT: %[[bcast:.+]] = bitcast i8** %[[gepiv]] to double**
; CHECK-NEXT: %[[metaload:.+]] = load double*, double** %[[bcast]]
; CHECK-NEXT: %[[load:.+]] = load double, double* %[[metaload]]
; CHECK-NEXT: %[[metaload:.+]] = load i8*, i8** %[[gepiv]]
; CHECK-NEXT: %[[bcast:.+]] = bitcast i8* %[[metaload]] to double*
; CHECK-NEXT: %[[load:.+]] = load double, double* %[[bcast]]
; this store is optional and could get removed by DCE
; CHECK-NEXT: store double 0.000000e+00, double* %[[metaload]]
; CHECK-NEXT: store double 0.000000e+00, double* %[[bcast]]
; CHECK-NEXT: %[[xadd]] = fadd fast double %"x'de.0.i", %[[load]]
; this reload really should be eliminated
; CHECK-NEXT: %[[recallpload2free:.+]] = load i8*, i8** %[[gepiv]]
; CHECK-NEXT: call void @_ZdlPv(i8* nonnull %[[recallpload2free]]) #5
; CHECK-NEXT: call void @_ZdlPv(i8* nonnull %[[metaload]]) #5
; CHECK-NEXT: %[[heregep:.+]] = getelementptr i8*, i8** %call_malloccache.i, i64 %[[antivar]]
; CHECK-NEXT: %[[callload2free:.+]] = load i8*, i8** %[[heregep]]
; CHECK-NEXT: call void @_ZdlPv(i8* %[[callload2free]]) #5
11 changes: 5 additions & 6 deletions enzyme/test/Enzyme/initializemany.ll
@@ -181,13 +181,12 @@ attributes #4 = { nounwind }
; CHECK-NEXT: %[[antivar:.+]] = phi i64 [ %wide.trip.count, %entry ], [ %[[sub:.+]], %invertfor.body ]
; CHECK-NEXT: %[[sub]] = add i64 %[[antivar]], -1
; CHECK-NEXT: %[[geper:.+]] = getelementptr i8*, i8** %0, i64 %[[sub]]
; CHECK-NEXT: %[[bc:.+]] = bitcast i8** %[[geper]] to double**
; CHECK-NEXT: %[[metaload:.+]] = load double*, double** %[[bc]], align 8
; CHECK-NEXT: %[[load:.+]] = load double, double* %[[metaload]], align 8
; CHECK-NEXT: store double 0.000000e+00, double* %[[metaload]], align 8
; CHECK-NEXT: %[[metaload:.+]] = load i8*, i8** %[[geper]], align 8
; CHECK-NEXT: %[[bc:.+]] = bitcast i8* %[[metaload]] to double*
; CHECK-NEXT: %[[load:.+]] = load double, double* %[[bc]], align 8
; CHECK-NEXT: store double 0.000000e+00, double* %[[bc]], align 8
; CHECK-NEXT: %[[added]] = fadd fast double %"x'de.0", %[[load]]
; CHECK-NEXT: %[[tofree:.+]] = load i8*, i8** %[[geper]], align 8
; CHECK-NEXT: tail call void @free(i8* nonnull %[[tofree]])
; CHECK-NEXT: tail call void @free(i8* nonnull %[[metaload]])
; CHECK-NEXT: %[[lcmp:.+]] = icmp eq i64 %[[sub]], 0
; CHECK-NEXT: br i1 %[[lcmp]], label %invertentry, label %invertfor.body
; CHECK-NEXT: }
11 changes: 5 additions & 6 deletions enzyme/test/Enzyme/llist.ll
@@ -126,14 +126,13 @@ attributes #4 = { nounwind }
; CHECK-NEXT: %[[antivar:.+]] = phi i64 [ %n, %invertfor.cond.cleanup.i ], [ %[[sub:.+]], %invertfor.body.i ]
; CHECK-NEXT: %[[sub]] = add i64 %[[antivar]], -1
; CHECK-NEXT: %[[gep:.+]] = getelementptr i8*, i8** %"call'mi_malloccache.i", i64 %[[antivar]]
; CHECK-NEXT: %[[ccast:.+]] = bitcast i8** %[[gep]] to double**
; CHECK-NEXT: %[[loadcache:.+]] = load double*, double** %[[ccast]]
; CHECK-NEXT: %[[load:.+]] = load double, double* %[[loadcache]]
; CHECK-NEXT: %[[loadcache:.+]] = load i8*, i8** %[[gep]]
; CHECK-NEXT: %[[ccast:.+]] = bitcast i8* %[[loadcache]] to double*
; CHECK-NEXT: %[[load:.+]] = load double, double* %[[ccast]]
; this store is optional and could get removed by DCE
; CHECK-NEXT: store double 0.000000e+00, double* %[[loadcache]]
; CHECK-NEXT: store double 0.000000e+00, double* %[[ccast]]
; CHECK-NEXT: %[[add]] = fadd fast double %"x'de.0.i", %[[load]]
; CHECK-NEXT: %[[prefree2:.+]] = load i8*, i8** %[[gep]]
; CHECK-NEXT: call void @free(i8* nonnull %[[prefree2]]) #4
; CHECK-NEXT: call void @free(i8* nonnull %[[loadcache]]) #4
; CHECK-NEXT: %[[gepcall:.+]] = getelementptr i8*, i8** %call_malloccache.i, i64 %[[antivar]]
; CHECK-NEXT: %[[loadprefree:.+]] = load i8*, i8** %[[gepcall]]
; CHECK-NEXT: call void @free(i8* %[[loadprefree]]) #4
