Atomic instruction on a global value considered active (should be inactive) #2212
Description
I'm using the `EnzymeCreateForwardDiff` function in Enzyme to calculate the gradient of a function (`rhs`). That is, I am differentiating a vector-valued function. I'm using threading: the function is designed to be called from different threads, passing in two `int`s that give the current thread ID and the total number of threads. Each thread computes part of the vector of values.
I'm using atomics to implement synchronisation barriers within the function, and this is where my problem is occurring. I have a global value, `thread_counter`, which is used to implement the barrier. Enzyme is incorrectly considering this global as active, but it should have no effect on the values I'm taking the gradient of. Is there a way of forcing Enzyme to consider `thread_counter` as inactive?
The IR for the function and the error I'm getting are reproduced below:
error: <unknown>:0:0: in function preprocess_rhs void (double, ptr, ptr, ptr, i32, i32): Enzyme: ; Function Attrs: mustprogress nofree nounwind willreturn memory(readwrite, inaccessiblemem: none)
define void @preprocess_rhs(double %0, ptr noalias nocapture readonly %1, ptr noalias nocapture writeonly %2, ptr noalias nocapture writeonly %3, i32 %4, i32 %5) local_unnamed_addr #7 {
; preprocess_rhs - one thread's share of a 50-element index range, processed
; in three passes, with one atomic increment of @thread_counter after pass 1.
;
; Parameters (thread roles of %4/%5 come from the issue text; the rest are
; read off the body below):
;   %0 (double) - not referenced anywhere in this body.
;   %1 (ptr)    - read-only input; reads %1[k] and %y[k] = %1[50 + k].
;   %2 (ptr)    - write-only output; pass 1 stores %2[k - %start] = %1[k] + %y[k].
;   %3 (ptr)    - write-only output; pass 2 stores %3[k - %start] = %1[k],
;                 pass 3 stores %3[50 + k - %start] = %y[k].
;   %4 (i32)    - index of the calling thread.
;   %5 (i32)    - total number of threads. It is used as a udiv divisor, so it
;                 is assumed to be non-zero (udiv by 0 is undefined behaviour).
; k runs over [%start, %end3) where, in unsigned i32 arithmetic,
;   %start = (%4 * 50) / %5   and   %end3 = umin((%4 * 50 + 50) / %5, 50).
; Outputs are indexed relative to %start, not by k.
;
; @thread_counter: every path executes exactly one
; `atomicrmw add ptr @thread_counter, i32 1 monotonic` followed by one
; monotonic atomic load of it:
;   - %6 and %current_value.i74 on the early-exit path;
;   - %22 and %current_value.i on the main path.
; None of these four results is used. Nothing derived from @thread_counter
; flows into a double value, an address, or a store, so the counter carries
; no derivative information. Enzyme nevertheless reports %6 as an active
; atomic.
;
; NOTE(review): no loop waits on @thread_counter in this function, so the
; increment/load pair does not block. With only monotonic ordering it also
; does not order the surrounding non-atomic stores to %2/%3 for other
; threads. The counter is never reset here either. If a real barrier is
; intended, confirm where the wait happens and whether acquire/release
; ordering is required.
entry:
; %y points 50 doubles past %1; it is the second input operand below.
%y = getelementptr inbounds double, ptr %1, i64 50
; First index owned by this thread. A %start beyond index 49 means the
; thread has no elements to process at all.
%i_times_size = mul i32 %4, 50
%start = udiv i32 %i_times_size, %5
%done = icmp ugt i32 %start, 49
br i1 %done, label %exit.thread, label %threading_block
exit.thread: ; preds = %entry
; No elements for this thread, but it still increments @thread_counter once,
; as the main path does in %threading_block23. %6 is the instruction named
; in Enzyme's "Active atomic inst not yet handled" error. Its result, like
; %current_value.i74, is never used.
%6 = atomicrmw add ptr @thread_counter, i32 1 monotonic, align 4
%current_value.i74 = load atomic i32, ptr @thread_counter monotonic, align 4
br label %exit67
threading_block: ; preds = %entry
%i_plus_one_times_size = add i32 %i_times_size, 50
%end = udiv i32 %i_plus_one_times_size, %5
; One past this thread's last index, clamped to 50.
%end3 = tail call i32 @llvm.umin.i32(i32 %end, i32 50) #8
%7 = zext i32 %start to i64
; Trip count for each pass: %12 = umax(%end3, %start + 1) - %start, which
; equals max(%end3 - %start, 1). %10 (= %12 - 1) is zero when exactly one
; iteration runs; in that case the two-wide loop is skipped.
; NOTE(review): each pass runs at least once even if %end3 <= %start (a
; thread with an empty range), because the scalar loops test their exit
; condition only after the first iteration. Confirm this is intended.
%8 = add nuw i32 %start, 1
%umax = tail call i32 @llvm.umax.i32(i32 %end3, i32 %8) #8
%9 = xor i32 %start, -1
%10 = add i32 %umax, %9
%11 = zext i32 %10 to i64
%12 = add nuw nsw i64 %11, 1
%min.iters.check = icmp eq i32 %10, 0
br i1 %min.iters.check, label %r-0.preheader, label %vector.ph
vector.ph: ; preds = %threading_block
; Pass 1, two-wide part: covers the even-sized prefix %n.vec of the trip
; count. %ind.end is the global index the scalar loop resumes from.
%n.vec = and i64 %12, -2
%ind.end = add nuw nsw i64 %n.vec, %7
br label %vector.body
vector.body: ; preds = %vector.body, %vector.ph
; %13 is the local index j (step 2) and %offset.idx = %start + j.
; Stores %2[j..j+1] = %1[%start+j..] + %y[%start+j..].
%iv = phi i64 [ %iv.next, %vector.body ], [ 0, %vector.ph ]
%13 = shl nuw i64 %iv, 1
%iv.next = add nuw nsw i64 %iv, 1
%offset.idx = add i64 %13, %7
%14 = getelementptr inbounds double, ptr %1, i64 %offset.idx
%wide.load = load <2 x double>, ptr %14, align 8
%15 = getelementptr inbounds double, ptr %y, i64 %offset.idx
%wide.load95 = load <2 x double>, ptr %15, align 8
%16 = fadd <2 x double> %wide.load, %wide.load95
%17 = getelementptr inbounds double, ptr %2, i64 %13
store <2 x double> %16, ptr %17, align 8
%index.next = add nuw i64 %13, 2
%18 = icmp eq i64 %index.next, %n.vec
br i1 %18, label %middle.block, label %vector.body, !llvm.loop !4
middle.block: ; preds = %vector.body
; Even trip count: pass 1 is complete. Odd: the scalar loop does the last element.
%cmp.n = icmp eq i64 %12, %n.vec
br i1 %cmp.n, label %threading_block23, label %r-0.preheader
r-0.preheader: ; preds = %middle.block, %threading_block
; Starting global index (%indvars.iv83.ph) and local index (%indvars.iv.ph)
; for the scalar loop. Both come either from the two-wide loop or from
; %start / 0.
%indvars.iv83.ph = phi i64 [ %7, %threading_block ], [ %ind.end, %middle.block ]
%indvars.iv.ph = phi i64 [ 0, %threading_block ], [ %n.vec, %middle.block ]
br label %r-0
r-0: ; preds = %r-0, %r-0.preheader
; Pass 1, scalar: %2[%19] = %1[%20] + %y[%20], where %19 is the local index
; and %20 is the global index k. The exit test is at the bottom: the loop
; continues while k + 1 < %end3 (compared as i32).
%iv1 = phi i64 [ %iv.next2, %r-0 ], [ 0, %r-0.preheader ]
%19 = add nuw nsw i64 %indvars.iv.ph, %iv1
%iv.next2 = add nuw nsw i64 %iv1, 1
%20 = add nuw nsw i64 %indvars.iv83.ph, %iv1
%indvars.iv.next = add nuw nsw i64 %19, 1
%r-06 = getelementptr inbounds double, ptr %1, i64 %20
%r-07 = load double, ptr %r-06, align 8
%r-08 = getelementptr inbounds double, ptr %y, i64 %20
%r-09 = load double, ptr %r-08, align 8
%r-010 = fadd double %r-07, %r-09
%r-012 = getelementptr inbounds double, ptr %2, i64 %19
store double %r-010, ptr %r-012, align 8
%indvars.iv.next84 = add nuw nsw i64 %20, 1
%21 = trunc i64 %indvars.iv.next84 to i32
%r-014 = icmp ugt i32 %end3, %21
br i1 %r-014, label %r-0, label %threading_block23.loopexit, !llvm.loop !5
threading_block23.loopexit: ; preds = %r-0
br label %threading_block23
threading_block23: ; preds = %threading_block23.loopexit, %middle.block
; Pass 1 finished: increment @thread_counter once. %22 and %current_value.i
; are never used, and execution falls straight through into pass 2 without
; waiting on the counter.
%22 = atomicrmw add ptr @thread_counter, i32 1 monotonic, align 4
%current_value.i = load atomic i32, ptr @thread_counter monotonic, align 4
; Pass 2 recomputes the same trip count as pass 1 (%27 equals %12).
%23 = add nuw i32 %start, 1
%umax96 = tail call i32 @llvm.umax.i32(i32 %end3, i32 %23) #8
%24 = xor i32 %start, -1
%25 = add i32 %umax96, %24
%26 = zext i32 %25 to i64
%27 = add nuw nsw i64 %26, 1
%min.iters.check99 = icmp eq i32 %25, 0
br i1 %min.iters.check99, label %F-0.preheader, label %vector.ph100
vector.ph100: ; preds = %threading_block23
; Pass 2, two-wide part: %3[j..j+1] = %1[%start+j..]. %ind.end103 is the
; resume global index and %ind.end105 is the resume i32 local index.
%n.vec102 = and i64 %27, -2
%ind.end103 = add nuw nsw i64 %n.vec102, %7
%ind.end105 = trunc i64 %n.vec102 to i32
br label %vector.body108
vector.body108: ; preds = %vector.body108, %vector.ph100
; %28 is the local index j. %sext/%30 sign-extend its low 32 bits, matching
; the i32 local index used by the scalar loop F-0.
%iv3 = phi i64 [ %iv.next4, %vector.body108 ], [ 0, %vector.ph100 ]
%28 = shl nuw i64 %iv3, 1
%iv.next4 = add nuw nsw i64 %iv3, 1
%offset.idx111 = add i64 %28, %7
%29 = getelementptr inbounds double, ptr %1, i64 %offset.idx111
%wide.load112 = load <2 x double>, ptr %29, align 8
%sext = shl i64 %28, 32
%30 = ashr exact i64 %sext, 32
%31 = getelementptr inbounds double, ptr %3, i64 %30
store <2 x double> %wide.load112, ptr %31, align 8
%index.next113 = add nuw i64 %28, 2
%32 = icmp eq i64 %index.next113, %n.vec102
br i1 %32, label %middle.block97, label %vector.body108, !llvm.loop !6
middle.block97: ; preds = %vector.body108
%cmp.n107 = icmp eq i64 %27, %n.vec102
br i1 %cmp.n107, label %F-1.preheader, label %F-0.preheader
F-0.preheader: ; preds = %middle.block97, %threading_block23
%indvars.iv88.ph = phi i64 [ %7, %threading_block23 ], [ %ind.end103, %middle.block97 ]
%next_expr_index3279.ph = phi i32 [ 0, %threading_block23 ], [ %ind.end105, %middle.block97 ]
%33 = zext i32 %next_expr_index3279.ph to i64
br label %F-0
F-0: ; preds = %F-0, %F-0.preheader
; Pass 2, scalar: %3[%35] = %1[%36], where %35 is the i32 local index and
; %36 is the global index k. The exit test is at the bottom, as in r-0.
%iv5 = phi i64 [ %iv.next6, %F-0 ], [ 0, %F-0.preheader ]
%34 = add i64 %33, %iv5
%iv.next6 = add nuw nsw i64 %iv5, 1
%35 = trunc i64 %34 to i32
%36 = add nuw nsw i64 %indvars.iv88.ph, %iv5
%next_expr_index32 = add i32 %35, 1
%F-033 = getelementptr inbounds double, ptr %1, i64 %36
%F-034 = load double, ptr %F-033, align 8
%37 = sext i32 %35 to i64
%F-036 = getelementptr inbounds double, ptr %3, i64 %37
store double %F-034, ptr %F-036, align 8
%indvars.iv.next89 = add nuw nsw i64 %36, 1
%38 = trunc i64 %indvars.iv.next89 to i32
%F-039 = icmp ugt i32 %end3, %38
br i1 %F-039, label %F-0, label %F-1.preheader.loopexit, !llvm.loop !7
F-1.preheader.loopexit: ; preds = %F-0
br label %F-1.preheader
F-1.preheader: ; preds = %F-1.preheader.loopexit, %middle.block97
; Pass 3 recomputes the same trip count again (%43 equals %12).
%39 = add nuw i32 %start, 1
%umax114 = tail call i32 @llvm.umax.i32(i32 %end3, i32 %39) #8
%40 = xor i32 %start, -1
%41 = add i32 %umax114, %40
%42 = zext i32 %41 to i64
%43 = add nuw nsw i64 %42, 1
%min.iters.check117 = icmp eq i32 %41, 0
br i1 %min.iters.check117, label %F-1.preheader133, label %vector.ph118
vector.ph118: ; preds = %F-1.preheader
; Pass 3, two-wide part: %3[50 + j..] = %y[%start + j..].
%n.vec120 = and i64 %43, -2
%ind.end121 = add nuw nsw i64 %n.vec120, %7
%ind.end123 = trunc i64 %n.vec120 to i32
br label %vector.body126
vector.body126: ; preds = %vector.body126, %vector.ph118
; %47 = sext(trunc(j) + 50). Adding 214748364800 (= 50 << 32) before the
; arithmetic shift right by 32 performs the i32 "+ 50".
%iv7 = phi i64 [ %iv.next8, %vector.body126 ], [ 0, %vector.ph118 ]
%44 = shl nuw i64 %iv7, 1
%iv.next8 = add nuw nsw i64 %iv7, 1
%offset.idx129 = add i64 %44, %7
%45 = getelementptr inbounds double, ptr %y, i64 %offset.idx129
%wide.load130 = load <2 x double>, ptr %45, align 8
%46 = shl i64 %44, 32
%sext132 = add i64 %46, 214748364800
%47 = ashr exact i64 %sext132, 32
%48 = getelementptr inbounds double, ptr %3, i64 %47
store <2 x double> %wide.load130, ptr %48, align 8
%index.next131 = add nuw i64 %44, 2
%49 = icmp eq i64 %index.next131, %n.vec120
br i1 %49, label %middle.block115, label %vector.body126, !llvm.loop !8
middle.block115: ; preds = %vector.body126
%cmp.n125 = icmp eq i64 %43, %n.vec120
br i1 %cmp.n125, label %exit67, label %F-1.preheader133
F-1.preheader133: ; preds = %middle.block115, %F-1.preheader
%indvars.iv91.ph = phi i64 [ %7, %F-1.preheader ], [ %ind.end121, %middle.block115 ]
%next_expr_index5881.ph = phi i32 [ 0, %F-1.preheader ], [ %ind.end123, %middle.block115 ]
%50 = zext i32 %next_expr_index5881.ph to i64
br label %F-1
F-1: ; preds = %F-1, %F-1.preheader133
; Pass 3, scalar: %3[%52 + 50] = %y[%53], where %52 is the i32 local index
; and %53 is the global index k.
%iv9 = phi i64 [ %iv.next10, %F-1 ], [ 0, %F-1.preheader133 ]
%51 = add i64 %50, %iv9
%iv.next10 = add nuw nsw i64 %iv9, 1
%52 = trunc i64 %51 to i32
%53 = add nuw nsw i64 %indvars.iv91.ph, %iv9
%next_expr_index58 = add i32 %52, 1
%F-159 = getelementptr inbounds double, ptr %y, i64 %53
%F-160 = load double, ptr %F-159, align 8
%F-161 = add i32 %52, 50
%54 = sext i32 %F-161 to i64
%F-162 = getelementptr inbounds double, ptr %3, i64 %54
store double %F-160, ptr %F-162, align 8
%indvars.iv.next92 = add nuw nsw i64 %53, 1
%55 = trunc i64 %indvars.iv.next92 to i32
%F-165 = icmp ugt i32 %end3, %55
br i1 %F-165, label %F-1, label %exit67.loopexit, !llvm.loop !9
exit67.loopexit: ; preds = %F-1
br label %exit67
exit67: ; preds = %exit67.loopexit, %middle.block115, %exit.thread
ret void
}
%6 = atomicrmw add ptr @thread_counter, i32 1 monotonic, align 4
Active atomic inst not yet handled
Activity