Skip to content

Commit bf3958d

Browse files
committed
JLInstSimplify multi arg
1 parent 636db9f commit bf3958d

File tree

2 files changed

+126
-15
lines changed

2 files changed

+126
-15
lines changed

enzyme/Enzyme/JLInstSimplify.cpp

Lines changed: 100 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,58 @@ bool notCapturedBefore(llvm::Value *V, Instruction *inst) {
129129
return true;
130130
}
131131

132+
/// Collect the base objects that \p V may resolve to.
///
/// Each pending value is resolved with getBaseObject(). When the result is a
/// PHI node or a select, its incoming / selected values are queued for the
/// same treatment instead of the PHI or select being recorded; any other
/// result is added to the returned set.
///
/// \param V             value whose base objects are wanted.
/// \param offsetAllowed forwarded unchanged to getBaseObject().
/// \returns the distinct non-PHI, non-select results reached from \p V.
static inline SetVector<llvm::Value *> getBaseObjects(llvm::Value *V,
                                                      bool offsetAllowed) {
  SetVector<llvm::Value *> bases;
  SmallPtrSet<llvm::Value *, 2> visited;
  SmallVector<llvm::Value *, 1> worklist;
  worklist.push_back(V);

  while (!worklist.empty()) {
    auto next = worklist.pop_back_val();

    // Expand every value at most once; this also stops PHI cycles from
    // looping forever.
    if (!visited.insert(next).second)
      continue;

    auto base = getBaseObject(next, offsetAllowed);

    if (auto phi = dyn_cast<PHINode>(base)) {
      // Look through the PHI: queue each incoming value.
      for (auto &incoming : phi->incoming_values())
        worklist.push_back(incoming);
    } else if (auto sel = dyn_cast<SelectInst>(base)) {
      // Look through the select: queue both arms.
      worklist.push_back(sel->getTrueValue());
      worklist.push_back(sel->getFalseValue());
    } else {
      bases.insert(base);
    }
  }
  return bases;
}
161+
162+
/// Decide whether every base object in \p lhs_v can be treated as a different
/// pointer from every base object in \p rhs_v.
///
/// A pair (lhs, rhs) is accepted only when the two values are distinct and
/// either
///   - lhs is noalias and rhs is noalias or a function argument, or
///   - rhs is noalias and lhs is a function argument.
/// Two plain arguments are rejected because they may be the same pointer, and
/// a noalias value paired with an arbitrary pointer is rejected because that
/// pointer may have been derived from the noalias value.
///
/// \param lhs_v base objects of the left-hand operand.
/// \param rhs_v base objects of the right-hand operand.
/// \returns true only if both sets are non-empty and every pair is accepted.
bool noaliased_or_arg(SetVector<llvm::Value *> &lhs_v,
                      SetVector<llvm::Value *> &rhs_v) {
  // Without a known base object on both sides nothing can be proven, so do
  // not let the loops below succeed vacuously.
  if (lhs_v.empty() || rhs_v.empty())
    return false;

  for (auto lhs : lhs_v) {
    const bool lhs_na = isNoAlias(lhs);
    const bool lhs_arg = isa<Argument>(lhs);

    // This LHS value is neither noalias nor an argument.
    if (!lhs_na && !lhs_arg)
      return false;

    for (auto rhs : rhs_v) {
      // The very same base object on both sides may compare equal.
      if (lhs == rhs)
        return false;

      // A noalias RHS is fine against any LHS that reached this point
      // (noalias or argument).
      if (isNoAlias(rhs))
        continue;

      // RHS is not noalias, so the LHS has to be; an argument LHS against a
      // non-noalias RHS proves nothing.
      if (!lhs_na)
        return false;

      // LHS is noalias: the only other acceptable RHS is an argument.
      if (!isa<Argument>(rhs))
        return false;
    }
  }
  return true;
}
183+
132184
bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
133185
llvm::AAResults &AA, llvm::LoopInfo &LI) {
134186
bool changed = false;
@@ -175,33 +227,59 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
175227
}
176228

177229
if (legal) {
178-
auto lhs = getBaseObject(I.getOperand(0), /*offsetAllowed*/ false);
179-
auto rhs = getBaseObject(I.getOperand(1), /*offsetAllowed*/ false);
180-
if (lhs == rhs) {
230+
auto lhs_v = getBaseObjects(I.getOperand(0), /*offsetAllowed*/ false);
231+
auto rhs_v = getBaseObjects(I.getOperand(1), /*offsetAllowed*/ false);
232+
if (lhs_v.size() == 1 && rhs_v.size() == 1 && lhs_v[0] == rhs_v[0]) {
181233
auto repval = ICmpInst::isTrueWhenEqual(pred)
182234
? ConstantInt::get(I.getType(), 1)
183235
: ConstantInt::get(I.getType(), 0);
184236
I.replaceAllUsesWith(repval);
185237
changed = true;
186238
continue;
187239
}
188-
if ((isNoAlias(lhs) && (isNoAlias(rhs) || isa<Argument>(rhs))) ||
189-
(isNoAlias(rhs) && isa<Argument>(lhs))) {
240+
if (noaliased_or_arg(lhs_v, rhs_v)) {
190241
auto repval = ICmpInst::isTrueWhenEqual(pred)
191242
? ConstantInt::get(I.getType(), 0)
192243
: ConstantInt::get(I.getType(), 1);
193244
I.replaceAllUsesWith(repval);
194245
changed = true;
195246
continue;
196247
}
197-
auto llhs = dyn_cast<LoadInst>(lhs);
198-
auto lrhs = dyn_cast<LoadInst>(rhs);
199-
if (llhs && lrhs && isa<PointerType>(llhs->getType()) &&
200-
isa<PointerType>(lrhs->getType())) {
201-
auto lhsv =
202-
getBaseObject(llhs->getOperand(0), /*offsetAllowed*/ false);
203-
auto rhsv =
204-
getBaseObject(lrhs->getOperand(0), /*offsetAllowed*/ false);
248+
bool loadlegal = true;
249+
SmallVector<LoadInst *, 1> llhs, lrhs;
250+
for (auto lhs : lhs_v) {
251+
auto ld = dyn_cast<LoadInst>(lhs);
252+
if (!ld || !isa<PointerType>(ld->getType())) {
253+
loadlegal = false;
254+
break;
255+
}
256+
llhs.push_back(ld);
257+
}
258+
for (auto rhs : rhs_v) {
259+
auto ld = dyn_cast<LoadInst>(rhs);
260+
if (!ld || !isa<PointerType>(ld->getType())) {
261+
loadlegal = false;
262+
break;
263+
}
264+
lrhs.push_back(ld);
265+
}
266+
SetVector<Value *> llhs_s, lrhs_s;
267+
for (auto v : llhs) {
268+
for (auto obj :
269+
getBaseObjects(v->getOperand(0), /*offsetAllowed*/ false)) {
270+
llhs_s.insert(obj);
271+
}
272+
}
273+
for (auto v : lrhs) {
274+
for (auto obj :
275+
getBaseObjects(v->getOperand(0), /*offsetAllowed*/ false)) {
276+
lrhs_s.insert(obj);
277+
}
278+
}
279+
// TODO handle multi size
280+
if (llhs_s.size() == 1 && lrhs_s.size() == 1 && loadlegal) {
281+
auto lhsv = llhs_s[0];
282+
auto rhsv = lrhs_s[0];
205283
if ((isNoAlias(lhsv) && (isNoAlias(rhsv) || isa<Argument>(rhsv) ||
206284
notCapturedBefore(lhsv, &I))) ||
207285
(isNoAlias(rhsv) &&
@@ -225,8 +303,15 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
225303
if (!I->mayWriteToMemory())
226304
return /*earlyBreak*/ false;
227305

228-
for (auto LI : {llhs, lrhs})
229-
if (writesToMemoryReadBy(AA, TLI,
306+
for (auto LI : llhs)
307+
if (writesToMemoryReadBy(nullptr, AA, TLI,
308+
/*maybeReader*/ LI,
309+
/*maybeWriter*/ I)) {
310+
overwritten = true;
311+
return /*earlyBreak*/ true;
312+
}
313+
for (auto LI : lrhs)
314+
if (writesToMemoryReadBy(nullptr, AA, TLI,
230315
/*maybeReader*/ LI,
231316
/*maybeWriter*/ I)) {
232317
overwritten = true;
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -jl-inst-simplify -S | FileCheck %s; fi
2+
; RUN: %opt < %s %newLoadEnzyme -passes="jl-inst-simplify" -S | FileCheck %s
3+
4+
declare i8** @malloc(i64)
5+
6+
; Compares a pointer loaded from one noalias call result (%i31) with a phi of
; two pointers loaded from another noalias call result (%i5). The CHECK line at
; the end of this file expects the returned `icmp ne` to become `true`.
define fastcc i1 @augmented_julia__affine_normalize_1484(i1 %c) {
7+
%i5 = call noalias i8** @malloc(i64 16)
8+
br i1 %c, label %tval, label %fval
9+
10+
tval:
11+
%j29 = load i8*, i8** %i5, align 8
12+
br label %end
13+
14+
fval:
15+
%k29 = load i8*, i8** %i5, align 8
16+
br label %end
17+
18+
end:
19+
; Both incoming values are loads from %i5, one per predecessor block.
%i29 = phi i8* [ %j29, %tval ], [ %k29, %fval ]
20+
; Second noalias call result; the callee is a constant address cast to a
; function pointer.
%i31 = call noalias nonnull i8* addrspace(10)* inttoptr (i64 137352001798896 to i8* addrspace(10)* ({} addrspace(10)*, i64, i64)*)({} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 137351863426640 to {}*) to {} addrspace(10)*), i64 10, i64 10)
21+
%i35 = load i8*, i8* addrspace(10)* %i31, align 8
22+
; The comparison under test: load from %i31 versus the phi of loads from %i5.
%i39 = icmp ne i8* %i35, %i29
23+
ret i1 %i39
24+
}
25+
26+
; CHECK: ret i1 true

0 commit comments

Comments
 (0)