Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix calling convention and other small bugs #16

Merged
merged 12 commits into from
Oct 15, 2019
Prev Previous commit
Next Next commit
revamp loop calc
  • Loading branch information
wsmoses committed Oct 12, 2019
commit a8bdf3b1b5dce14911f871bf10c6060033b8baf9
57 changes: 33 additions & 24 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ std::pair<PHINode*,Value*> insertNewCanonicalIV(Loop* L, Type* Ty) {
return std::pair<PHINode*,Value*>(CanonicalIV,inc);
}

void removeRedundantIVs(BasicBlock* Header, BasicBlock* Preheader, PHINode* CanonicalIV, ScalarEvolution &SE, GradientUtils &gutils, Value* increment=nullptr) {
void removeRedundantIVs(const Loop* L, BasicBlock* Header, BasicBlock* Preheader, PHINode* CanonicalIV, ScalarEvolution &SE, GradientUtils &gutils, Value* increment, const SmallVectorImpl<BasicBlock*>&& latches) {
assert(Header);
assert(CanonicalIV);

Expand Down Expand Up @@ -535,26 +535,30 @@ void removeRedundantIVs(BasicBlock* Header, BasicBlock* Preheader, PHINode* Cano
gutils.erase(PN);
}

if (latches.size() == 1 && isa<BranchInst>(latches[0]->getTerminator()) && cast<BranchInst>(latches[0]->getTerminator())->isConditional())
for (auto use : CanonicalIV->users()) {
if (auto cmp = dyn_cast<ICmpInst>(use)) {
if (cmp->isUnsigned()) {
// Force i to be on LHS
if (cmp->getOperand(0) != CanonicalIV) {
//Below also swaps predicate correctly
cmp->swapOperands();
}
assert(cmp->getOperand(0) == CanonicalIV);
if (cast<BranchInst>(latches[0]->getTerminator())->getCondition() != cmp) continue;
// Force i to be on LHS
if (cmp->getOperand(0) != CanonicalIV) {
//Below also swaps predicate correctly
cmp->swapOperands();
}
assert(cmp->getOperand(0) == CanonicalIV);

auto scv = SE.getSCEVAtScope(cmp->getOperand(1), L);
if (cmp->isUnsigned() || (scv != SE.getCouldNotCompute() && SE.isKnownNonNegative(scv)) ) {

// valid replacements (since unsigned comparison and i starts at 0 counting up)

// * i < n => i != n, valid since first time i >= n occurs at i == n
if (cmp->getPredicate() == ICmpInst::ICMP_ULT) {
if (cmp->getPredicate() == ICmpInst::ICMP_ULT || cmp->getPredicate() == ICmpInst::ICMP_SLT) {
cmp->setPredicate(ICmpInst::ICMP_NE);
goto cend;
}

// * i <= n => i != n+1, valid since first time i > n occurs at i == n+1 [ which we assert is in bitrange as not infinite loop ]
if (cmp->getPredicate() == ICmpInst::ICMP_ULE) {
if (cmp->getPredicate() == ICmpInst::ICMP_ULE || cmp->getPredicate() == ICmpInst::ICMP_SLE) {
IRBuilder <>builder (Preheader->getTerminator());
if (auto inst = dyn_cast<Instruction>(cmp->getOperand(1))) {
builder.SetInsertPoint(inst->getNextNode());
Expand All @@ -565,13 +569,13 @@ void removeRedundantIVs(BasicBlock* Header, BasicBlock* Preheader, PHINode* Cano
}

// * i >= n => i == n, valid since first time i >= n occurs at i == n
if (cmp->getPredicate() == ICmpInst::ICMP_UGE) {
if (cmp->getPredicate() == ICmpInst::ICMP_UGE || cmp->getPredicate() == ICmpInst::ICMP_SGE) {
cmp->setPredicate(ICmpInst::ICMP_EQ);
goto cend;
}

// * i > n => i == n+1, valid since first time i > n occurs at i == n+1 [ which we assert is in bitrange as not infinite loop ]
if (cmp->getPredicate() == ICmpInst::ICMP_UGT) {
if (cmp->getPredicate() == ICmpInst::ICMP_UGT || cmp->getPredicate() == ICmpInst::ICMP_SGT) {
IRBuilder <>builder (Preheader->getTerminator());
if (auto inst = dyn_cast<Instruction>(cmp->getOperand(1))) {
builder.SetInsertPoint(inst->getNextNode());
Expand Down Expand Up @@ -618,39 +622,44 @@ void removeRedundantIVs(BasicBlock* Header, BasicBlock* Preheader, PHINode* Cano
gutils.erase(inst);
}

if (latches.size() == 1 && isa<BranchInst>(latches[0]->getTerminator()) && cast<BranchInst>(latches[0]->getTerminator())->isConditional())
for (auto use : increment->users()) {
if (auto cmp = dyn_cast<ICmpInst>(use)) {
if (cmp->isUnsigned()) {
// Force i+1 to be on LHS
if (cmp->getOperand(0) != increment) {
//Below also swaps predicate correctly
cmp->swapOperands();
}
assert(cmp->getOperand(0) == increment);
if (cast<BranchInst>(latches[0]->getTerminator())->getCondition() != cmp) continue;

// Force i+1 to be on LHS
if (cmp->getOperand(0) != increment) {
//Below also swaps predicate correctly
cmp->swapOperands();
}
assert(cmp->getOperand(0) == increment);

auto scv = SE.getSCEVAtScope(cmp->getOperand(1), L);
if (cmp->isUnsigned() || (scv != SE.getCouldNotCompute() && SE.isKnownNonNegative(scv)) ) {

// valid replacements (since unsigned comparison and i starts at 0 counting up)

// * i+1 < n => i+1 != n, valid since first time i+1 >= n occurs at i+1 == n
if (cmp->getPredicate() == ICmpInst::ICMP_ULT) {
if (cmp->getPredicate() == ICmpInst::ICMP_ULT || cmp->getPredicate() == ICmpInst::ICMP_SLT) {
cmp->setPredicate(ICmpInst::ICMP_NE);
continue;
}

// * i+1 <= n => i != n, valid since first time i+1 > n occurs at i+1 == n+1 => i == n
if (cmp->getPredicate() == ICmpInst::ICMP_ULE) {
if (cmp->getPredicate() == ICmpInst::ICMP_ULE || cmp->getPredicate() == ICmpInst::ICMP_SLE) {
cmp->setOperand(0, CanonicalIV);
cmp->setPredicate(ICmpInst::ICMP_NE);
continue;
}

// * i+1 >= n => i+1 == n, valid since first time i+1 >= n occurs at i+1 == n
if (cmp->getPredicate() == ICmpInst::ICMP_UGE) {
if (cmp->getPredicate() == ICmpInst::ICMP_UGE || cmp->getPredicate() == ICmpInst::ICMP_SGE) {
cmp->setPredicate(ICmpInst::ICMP_EQ);
continue;
}

// * i+1 > n => i == n, valid since first time i+1 > n occurs at i+1 == n+1 => i == n
if (cmp->getPredicate() == ICmpInst::ICMP_UGT) {
if (cmp->getPredicate() == ICmpInst::ICMP_UGT || cmp->getPredicate() == ICmpInst::ICMP_SGT) {
cmp->setOperand(0, CanonicalIV);
cmp->setPredicate(ICmpInst::ICMP_EQ);
continue;
Expand Down Expand Up @@ -686,7 +695,7 @@ bool getContextM(BasicBlock *BB, LoopContext &loopContext, std::map<Loop*,LoopCo
auto pair = insertNewCanonicalIV(L, Type::getInt64Ty(BB->getContext()));
PHINode* CanonicalIV = pair.first;
assert(CanonicalIV);
removeRedundantIVs(loopContexts[L].header, loopContexts[L].preheader, CanonicalIV, SE, gutils, pair.second);
removeRedundantIVs(L, loopContexts[L].header, loopContexts[L].preheader, CanonicalIV, SE, gutils, pair.second, fake::SCEVExpander::getLatches(L, loopContexts[L].exitBlocks));
loopContexts[L].var = CanonicalIV;
loopContexts[L].antivar = PHINode::Create(CanonicalIV->getType(), CanonicalIV->getNumIncomingValues(), CanonicalIV->getName()+"'phi");

Expand Down
10 changes: 5 additions & 5 deletions enzyme/test/Enzyme/llist.ll
Original file line number Diff line number Diff line change
Expand Up @@ -172,17 +172,17 @@ attributes #4 = { nounwind }
; CHECK-NEXT: %"'ipl" = load %struct.n*, %struct.n** %"next'ipg", align 8
; CHECK-NEXT: %[[loadst]] = load %struct.n*, %struct.n** %next, align 8, !tbaa !8
; CHECK-NEXT: %cmp = icmp eq %struct.n* %[[loadst]], null
; CHECK-NEXT: br i1 %cmp, label %invertfor.body, label %for.body
; CHECK-NEXT: br i1 %cmp, label %[[antiloop:.+]], label %for.body

; CHECK: invertentry:
; CHECK-NEXT: ret {} undef

; CHECK: invertfor.body.preheader: ; preds = %invertfor.body
; CHECK: invertfor.body.preheader:
; CHECK-NEXT: tail call void @free(i8* nonnull %_realloccache)
; CHECK-NEXT: br label %invertentry

; CHECK: invertfor.body:
; CHECK-NEXT: %[[antivar:.+]] = phi i64 [ %[[subidx:.+]], %invertfor.body ], [ %[[preidx]], %for.body ]
; CHECK: [[antiloop]]:
; CHECK-NEXT: %[[antivar:.+]] = phi i64 [ %[[subidx:.+]], %[[antiloop]] ], [ %[[preidx]], %for.body ]
; CHECK-NEXT: %[[subidx]] = add i64 %[[antivar]], -1
; CHECK-NEXT: %[[structptr:.+]] = getelementptr %struct.n*, %struct.n** %[[bcalloc]], i64 %[[antivar]]
; CHECK-NEXT: %[[struct:.+]] = load %struct.n*, %struct.n** %[[structptr]]
Expand All @@ -191,5 +191,5 @@ attributes #4 = { nounwind }
; CHECK-NEXT: %[[addval:.+]] = fadd fast double %[[val0]], %[[differet]]
; CHECK-NEXT: store double %[[addval]], double* %"value'ipg"
; CHECK-NEXT: %[[cmpeq:.+]] = icmp eq i64 %[[antivar]], 0
; CHECK-NEXT: br i1 %[[cmpeq]], label %invertfor.body.preheader, label %invertfor.body
; CHECK-NEXT: br i1 %[[cmpeq]], label %invertfor.body.preheader, label %[[antiloop]]
; CHECK-NEXT: }
8 changes: 4 additions & 4 deletions enzyme/test/Enzyme/nllist.ll
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ attributes #4 = { nounwind }
; CHECK-NEXT: %[[dstructload]] = load %struct.n*, %struct.n** %"next'ipg", align 8
; CHECK-NEXT: %[[nextstruct]] = load %struct.n*, %struct.n** %next, align 8, !tbaa !7
; CHECK-NEXT: %[[mycmp:.+]] = icmp eq %struct.n* %[[nextstruct]], null
; CHECK-NEXT: br i1 %[[mycmp]], label %invertfor.cond.cleanup4, label %for.cond1.preheader
; CHECK-NEXT: br i1 %[[mycmp]], label %[[invertforcondcleanup:.+]], label %for.cond1.preheader

; CHECK: for.body5: ; preds = %for.body5, %for.cond1.preheader
; CHECK-NEXT: %[[iv:.+]] = phi i64 [ %[[ivnext:.+]], %for.body5 ], [ 0, %for.cond1.preheader ]
Expand All @@ -349,17 +349,17 @@ attributes #4 = { nounwind }

; CHECK: invertfor.cond1.preheader: ; preds = %invertfor.body5
; CHECK-NEXT: %[[icmp:.+]] = icmp eq i64 %[[antivar:.+]], 0
; CHECK-NEXT: br i1 %[[icmp]], label %invertfor.cond1.preheader.preheader, label %invertfor.cond.cleanup4
; CHECK-NEXT: br i1 %[[icmp]], label %invertfor.cond1.preheader.preheader, label %[[invertforcondcleanup]]

; CHECK: invertfor.cond.cleanup4:
; CHECK: [[invertforcondcleanup]]:
; CHECK-NEXT: %[[antivar]] = phi i64 [ %[[isub:.+]], %invertfor.cond1.preheader ], [ %[[preidx]], %for.cond.cleanup4 ]
; CHECK-NEXT: %[[isub]] = add i64 %[[antivar]], -1
; CHECK-NEXT: %[[toload:.+]] = getelementptr double*, double** %[[todoublep]], i64 %[[antivar]]
; CHECK-NEXT: %[[loadediv:.+]] = load double*, double** %[[toload]], align 8, !invariant.load
; CHECK-NEXT: br label %invertfor.body5

; CHECK: invertfor.body5:
; CHECK-NEXT: %[[mantivar:.+]] = phi i64 [ %times, %invertfor.cond.cleanup4 ], [ %[[idxsub:.+]], %invertfor.body5 ]
; CHECK-NEXT: %[[mantivar:.+]] = phi i64 [ %times, %[[invertforcondcleanup]] ], [ %[[idxsub:.+]], %invertfor.body5 ]
; CHECK-NEXT: %[[idxsub]] = add i64 %[[mantivar]], -1
; CHECK-NEXT: %"arrayidx'ipg" = getelementptr double, double* %[[loadediv]], i64 %[[mantivar]]
; CHECK-NEXT: %[[arrayload:.+]] = load double, double* %"arrayidx'ipg"
Expand Down
8 changes: 4 additions & 4 deletions enzyme/test/Enzyme/sumbr2.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -instcombine -simplifycfg -S | FileCheck %s
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -instcombine -simplifycfg -S -jump-threading -instsimplify -simplifycfg -adce -loop-deletion -simplifycfg | FileCheck %s

; Function Attrs: norecurse nounwind readonly uwtable
define dso_local double @sum(double* nocapture readonly %x, i64 %n) #0 {
Expand Down Expand Up @@ -45,15 +45,15 @@ attributes #2 = { nounwind }
; CHECK-NEXT: br i1 %[[exists]], label %diffesum.exit, label %[[antiloop:.+]]

; CHECK: [[antiloop]]:
; CHECK-NEXT: %"add'de.0.i" = phi double [ %[[m0dadd:.+]], %[[antiloop]] ], [ 1.000000e+00, %entry ]
; CHECK-NEXT: %[[dadd:.+]] = phi double [ %[[m0dadd:.+]], %[[antiloop]] ], [ 1.000000e+00, %entry ]
; CHECK-NEXT: %[[antivar:.+]] = phi i64 [ %[[sub:.+]], %[[antiloop]] ], [ %n, %entry ]
; CHECK-NEXT: %[[sub]] = add i64 %[[antivar]], -1
; CHECK-NEXT: %"arrayidx'ipg.i" = getelementptr double, double* %xp, i64 %[[antivar]]
; CHECK-NEXT: %[[toload:.+]] = load double, double* %"arrayidx'ipg.i", align 8
; CHECK-NEXT: %[[tostore:.+]] = fadd fast double %[[toload]], %"add'de.0.i"
; CHECK-NEXT: %[[tostore:.+]] = fadd fast double %[[toload]], %[[dadd]]
; CHECK-NEXT: store double %[[tostore]], double* %"arrayidx'ipg.i", align 8
; CHECK-NEXT: %res_unwrap.i = uitofp i64 %[[sub]] to double
; CHECK-NEXT: %[[m0dadd]] = fmul fast double %"add'de.0.i", %res_unwrap.i
; CHECK-NEXT: %[[m0dadd]] = fmul fast double %[[dadd]], %res_unwrap.i
; CHECK-NEXT: %[[itercmp:.+]] = icmp eq i64 %[[sub]], 0
; CHECK-NEXT: br i1 %[[itercmp]], label %diffesum.exit, label %invertextra.i

Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/Enzyme/sumsimple.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S -early-cse | FileCheck %s
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S -early-cse -simplifycfg | FileCheck %s

; Function Attrs: noinline nounwind uwtable
define dso_local void @f(double* %x, double** %y, i64 %n) #0 {
Expand Down
6 changes: 3 additions & 3 deletions enzyme/test/Enzyme/sumwithbreak.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -mem2reg -instcombine -correlated-propagation -adce -instcombine -simplifycfg -early-cse -simplifycfg -loop-unroll -instcombine -simplifycfg -gvn -jump-threading -instcombine -S | FileCheck %s
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -mem2reg -instcombine -correlated-propagation -adce -instcombine -simplifycfg -early-cse -simplifycfg -loop-unroll -instcombine -simplifycfg -gvn -jump-threading -instcombine -simplifycfg -S | FileCheck %s

; Function Attrs: noinline nounwind uwtable
define dso_local double @f(double* nocapture readonly %x, i64 %n) #0 {
Expand Down Expand Up @@ -57,8 +57,8 @@ attributes #0 = { noinline nounwind uwtable }
; CHECK-NEXT: %arrayidx4 = getelementptr inbounds double, double* %x, i64 %iv
; CHECK-NEXT: %0 = load double, double* %arrayidx4, align 8
; CHECK-NEXT: %add5 = fadd fast double %0, %data.016
; CHECK-NEXT: %cmp = icmp eq i64 %iv, %n
; CHECK-NEXT: br i1 %cmp, label %invertif.end.peel, label %for.body
; CHECK-NEXT: %cmp = icmp ult i64 %iv, %n
; CHECK-NEXT: br i1 %cmp, label %for.body, label %invertif.end.peel

; CHECK: invertentry:
; CHECK-NEXT: ret {} undef
Expand Down