Skip to content

atomic instruction of global value considered active (should be inactive) #2212

Open
@martinjrobins

Description

I'm using the EnzymeCreateForwardDiff function in Enzyme to calculate the gradient of a function (rhs). This function calculates the gradient of a vector-valued function. I'm using threading: the function is designed to be called from different threads, passing in two ints that give the current thread id and the total number of threads. Each thread computes part of the vector of values.

I'm using atomics to implement synchronisation barriers within the function, and this is where my problem is occurring. I have a global value, thread_counter, which is used to implement the barrier, and Enzyme is incorrectly considering this global as active, even though it should have no effect on the values I'm taking the gradient of. Is there a way of forcing Enzyme to consider thread_counter as inactive?

The IR for the function and the error I'm getting is reproduced below:

error: <unknown>:0:0: in function preprocess_rhs void (double, ptr, ptr, ptr, i32, i32): Enzyme: ; Function Attrs: mustprogress nofree nounwind willreturn memory(readwrite, inaccessiblemem: none)
; preprocess_rhs(double %0, ptr %1, ptr %2, ptr %3, i32 %4, i32 %5)
;
; Function body as dumped by Enzyme in the error message. Summary of the visible
; data flow:
;   * %0 is not referenced anywhere in the body.
;   * %1 is read as doubles at indices start+k and, through %y, at the same
;     indices shifted by 50 elements.
;   * %2 receives the sums %1[start+k] + %y[start+k], stored at index k.
;   * %3 receives copies: %1[start+k] at index k and %y[start+k] at index k + 50.
;   * @thread_counter is incremented by an atomicrmw add on every path to the
;     return. Neither the atomicrmw results (%6, %22) nor the atomic loads that
;     follow them (%current_value.i74, %current_value.i) are used by any other
;     instruction in this function.
; Here start = (%4 * 50) / %5 and end3 = umin((%4 * 50 + 50) / %5, 50), unsigned.
; NOTE(review): presumably %4 is the thread id and %5 the thread count mentioned
; in the issue text; the IR does not say which is which, so confirm.
define void @preprocess_rhs(double %0, ptr noalias nocapture readonly %1, ptr noalias nocapture writeonly %2, ptr noalias nocapture writeonly %3, i32 %4, i32 %5) local_unnamed_addr #7 {
entry:
  ; %y points 50 doubles past %1.
  %y = getelementptr inbounds double, ptr %1, i64 50
  %i_times_size = mul i32 %4, 50
  %start = udiv i32 %i_times_size, %5
  ; If start is beyond index 49, skip all loads/stores and go to exit.thread.
  %done = icmp ugt i32 %start, 49
  br i1 %done, label %exit.thread, label %threading_block

exit.thread:                                      ; preds = %entry
  ; Atomic increment of the global counter followed by an atomic reload; both
  ; results are unused. This atomicrmw is the instruction named in the error.
  %6 = atomicrmw add ptr @thread_counter, i32 1 monotonic, align 4
  %current_value.i74 = load atomic i32, ptr @thread_counter monotonic, align 4
  br label %exit67

threading_block:                                  ; preds = %entry
  %i_plus_one_times_size = add i32 %i_times_size, 50
  %end = udiv i32 %i_plus_one_times_size, %5
  %end3 = tail call i32 @llvm.umin.i32(i32 %end, i32 50) #8
  %7 = zext i32 %start to i64
  ; %10 = umax(end3, start + 1) - start - 1 (the xor with -1 is bitwise NOT, i.e. -start - 1);
  ; %12 = %10 + 1 is compared against %n.vec in middle.block.
  %8 = add nuw i32 %start, 1
  %umax = tail call i32 @llvm.umax.i32(i32 %end3, i32 %8) #8
  %9 = xor i32 %start, -1
  %10 = add i32 %umax, %9
  %11 = zext i32 %10 to i64
  %12 = add nuw nsw i64 %11, 1
  ; When %10 == 0 the vector loop is skipped and only the scalar loop r-0 runs.
  %min.iters.check = icmp eq i32 %10, 0
  br i1 %min.iters.check, label %r-0.preheader, label %vector.ph

vector.ph:                                        ; preds = %threading_block
  ; %n.vec is %12 with its lowest bit cleared (rounded down to a multiple of 2).
  %n.vec = and i64 %12, -2
  %ind.end = add nuw nsw i64 %n.vec, %7
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %vector.ph
  ; Two doubles per iteration: %2[%13 ..] = %1[%13 + start ..] + %y[%13 + start ..],
  ; where %13 = 2 * %iv.
  %iv = phi i64 [ %iv.next, %vector.body ], [ 0, %vector.ph ]
  %13 = shl nuw i64 %iv, 1
  %iv.next = add nuw nsw i64 %iv, 1
  %offset.idx = add i64 %13, %7
  %14 = getelementptr inbounds double, ptr %1, i64 %offset.idx
  %wide.load = load <2 x double>, ptr %14, align 8
  %15 = getelementptr inbounds double, ptr %y, i64 %offset.idx
  %wide.load95 = load <2 x double>, ptr %15, align 8
  %16 = fadd <2 x double> %wide.load, %wide.load95
  %17 = getelementptr inbounds double, ptr %2, i64 %13
  store <2 x double> %16, ptr %17, align 8
  %index.next = add nuw i64 %13, 2
  %18 = icmp eq i64 %index.next, %n.vec
  br i1 %18, label %middle.block, label %vector.body, !llvm.loop !4

middle.block:                                     ; preds = %vector.body
  ; If the vector loop covered all %12 elements, skip the scalar loop.
  %cmp.n = icmp eq i64 %12, %n.vec
  br i1 %cmp.n, label %threading_block23, label %r-0.preheader

r-0.preheader:                                    ; preds = %middle.block, %threading_block
  ; Starting indices for the scalar loop: source index (start, or start + n.vec)
  ; and destination index (0, or n.vec).
  %indvars.iv83.ph = phi i64 [ %7, %threading_block ], [ %ind.end, %middle.block ]
  %indvars.iv.ph = phi i64 [ 0, %threading_block ], [ %n.vec, %middle.block ]
  br label %r-0

r-0:                                              ; preds = %r-0, %r-0.preheader
  ; Scalar loop: %2[%19] = %1[%20] + %y[%20], one double per iteration, until the
  ; next source index reaches %end3.
  %iv1 = phi i64 [ %iv.next2, %r-0 ], [ 0, %r-0.preheader ]
  %19 = add nuw nsw i64 %indvars.iv.ph, %iv1
  %iv.next2 = add nuw nsw i64 %iv1, 1
  %20 = add nuw nsw i64 %indvars.iv83.ph, %iv1
  %indvars.iv.next = add nuw nsw i64 %19, 1
  %r-06 = getelementptr inbounds double, ptr %1, i64 %20
  %r-07 = load double, ptr %r-06, align 8
  %r-08 = getelementptr inbounds double, ptr %y, i64 %20
  %r-09 = load double, ptr %r-08, align 8
  %r-010 = fadd double %r-07, %r-09
  %r-012 = getelementptr inbounds double, ptr %2, i64 %19
  store double %r-010, ptr %r-012, align 8
  %indvars.iv.next84 = add nuw nsw i64 %20, 1
  %21 = trunc i64 %indvars.iv.next84 to i32
  %r-014 = icmp ugt i32 %end3, %21
  br i1 %r-014, label %r-0, label %threading_block23.loopexit, !llvm.loop !5

threading_block23.loopexit:                       ; preds = %r-0
  br label %threading_block23

threading_block23:                                ; preds = %threading_block23.loopexit, %middle.block
  ; Second atomic increment + reload of @thread_counter, placed between the writes
  ; to %2 above and the writes to %3 below; %22 and %current_value.i are unused.
  %22 = atomicrmw add ptr @thread_counter, i32 1 monotonic, align 4
  %current_value.i = load atomic i32, ptr @thread_counter monotonic, align 4
  ; Same count computation as in threading_block, recomputed here.
  %23 = add nuw i32 %start, 1
  %umax96 = tail call i32 @llvm.umax.i32(i32 %end3, i32 %23) #8
  %24 = xor i32 %start, -1
  %25 = add i32 %umax96, %24
  %26 = zext i32 %25 to i64
  %27 = add nuw nsw i64 %26, 1
  %min.iters.check99 = icmp eq i32 %25, 0
  br i1 %min.iters.check99, label %F-0.preheader, label %vector.ph100

vector.ph100:                                     ; preds = %threading_block23
  %n.vec102 = and i64 %27, -2
  %ind.end103 = add nuw nsw i64 %n.vec102, %7
  %ind.end105 = trunc i64 %n.vec102 to i32
  br label %vector.body108

vector.body108:                                   ; preds = %vector.body108, %vector.ph100
  ; Two doubles per iteration: copy %1[%28 + start ..] into %3[%30 ..], where %30 is
  ; the low 32 bits of %28 sign-extended (shl 32 followed by ashr 32).
  %iv3 = phi i64 [ %iv.next4, %vector.body108 ], [ 0, %vector.ph100 ]
  %28 = shl nuw i64 %iv3, 1
  %iv.next4 = add nuw nsw i64 %iv3, 1
  %offset.idx111 = add i64 %28, %7
  %29 = getelementptr inbounds double, ptr %1, i64 %offset.idx111
  %wide.load112 = load <2 x double>, ptr %29, align 8
  %sext = shl i64 %28, 32
  %30 = ashr exact i64 %sext, 32
  %31 = getelementptr inbounds double, ptr %3, i64 %30
  store <2 x double> %wide.load112, ptr %31, align 8
  %index.next113 = add nuw i64 %28, 2
  %32 = icmp eq i64 %index.next113, %n.vec102
  br i1 %32, label %middle.block97, label %vector.body108, !llvm.loop !6

middle.block97:                                   ; preds = %vector.body108
  %cmp.n107 = icmp eq i64 %27, %n.vec102
  br i1 %cmp.n107, label %F-1.preheader, label %F-0.preheader

F-0.preheader:                                    ; preds = %middle.block97, %threading_block23
  %indvars.iv88.ph = phi i64 [ %7, %threading_block23 ], [ %ind.end103, %middle.block97 ]
  %next_expr_index3279.ph = phi i32 [ 0, %threading_block23 ], [ %ind.end105, %middle.block97 ]
  %33 = zext i32 %next_expr_index3279.ph to i64
  br label %F-0

F-0:                                              ; preds = %F-0, %F-0.preheader
  ; Scalar loop: %3[sext(%35)] = %1[%36], one double per iteration, until the next
  ; source index reaches %end3.
  %iv5 = phi i64 [ %iv.next6, %F-0 ], [ 0, %F-0.preheader ]
  %34 = add i64 %33, %iv5
  %iv.next6 = add nuw nsw i64 %iv5, 1
  %35 = trunc i64 %34 to i32
  %36 = add nuw nsw i64 %indvars.iv88.ph, %iv5
  %next_expr_index32 = add i32 %35, 1
  %F-033 = getelementptr inbounds double, ptr %1, i64 %36
  %F-034 = load double, ptr %F-033, align 8
  %37 = sext i32 %35 to i64
  %F-036 = getelementptr inbounds double, ptr %3, i64 %37
  store double %F-034, ptr %F-036, align 8
  %indvars.iv.next89 = add nuw nsw i64 %36, 1
  %38 = trunc i64 %indvars.iv.next89 to i32
  %F-039 = icmp ugt i32 %end3, %38
  br i1 %F-039, label %F-0, label %F-1.preheader.loopexit, !llvm.loop !7

F-1.preheader.loopexit:                           ; preds = %F-0
  br label %F-1.preheader

F-1.preheader:                                    ; preds = %F-1.preheader.loopexit, %middle.block97
  ; Same count computation again, for the copy from %y into the second part of %3.
  %39 = add nuw i32 %start, 1
  %umax114 = tail call i32 @llvm.umax.i32(i32 %end3, i32 %39) #8
  %40 = xor i32 %start, -1
  %41 = add i32 %umax114, %40
  %42 = zext i32 %41 to i64
  %43 = add nuw nsw i64 %42, 1
  %min.iters.check117 = icmp eq i32 %41, 0
  br i1 %min.iters.check117, label %F-1.preheader133, label %vector.ph118

vector.ph118:                                     ; preds = %F-1.preheader
  %n.vec120 = and i64 %43, -2
  %ind.end121 = add nuw nsw i64 %n.vec120, %7
  %ind.end123 = trunc i64 %n.vec120 to i32
  br label %vector.body126

vector.body126:                                   ; preds = %vector.body126, %vector.ph118
  ; Two doubles per iteration: copy %y[%44 + start ..] into %3[%47 ..].
  %iv7 = phi i64 [ %iv.next8, %vector.body126 ], [ 0, %vector.ph118 ]
  %44 = shl nuw i64 %iv7, 1
  %iv.next8 = add nuw nsw i64 %iv7, 1
  %offset.idx129 = add i64 %44, %7
  %45 = getelementptr inbounds double, ptr %y, i64 %offset.idx129
  %wide.load130 = load <2 x double>, ptr %45, align 8
  ; 214748364800 == 50 << 32, so after the ashr below %47 is the low 32 bits of
  ; (%44 + 50), sign-extended to i64.
  %46 = shl i64 %44, 32
  %sext132 = add i64 %46, 214748364800
  %47 = ashr exact i64 %sext132, 32
  %48 = getelementptr inbounds double, ptr %3, i64 %47
  store <2 x double> %wide.load130, ptr %48, align 8
  %index.next131 = add nuw i64 %44, 2
  %49 = icmp eq i64 %index.next131, %n.vec120
  br i1 %49, label %middle.block115, label %vector.body126, !llvm.loop !8

middle.block115:                                  ; preds = %vector.body126
  %cmp.n125 = icmp eq i64 %43, %n.vec120
  br i1 %cmp.n125, label %exit67, label %F-1.preheader133

F-1.preheader133:                                 ; preds = %middle.block115, %F-1.preheader
  %indvars.iv91.ph = phi i64 [ %7, %F-1.preheader ], [ %ind.end121, %middle.block115 ]
  %next_expr_index5881.ph = phi i32 [ 0, %F-1.preheader ], [ %ind.end123, %middle.block115 ]
  %50 = zext i32 %next_expr_index5881.ph to i64
  br label %F-1

F-1:                                              ; preds = %F-1, %F-1.preheader133
  ; Scalar loop: %3[sext(%52 + 50)] = %y[%53], one double per iteration, until the
  ; next source index reaches %end3.
  %iv9 = phi i64 [ %iv.next10, %F-1 ], [ 0, %F-1.preheader133 ]
  %51 = add i64 %50, %iv9
  %iv.next10 = add nuw nsw i64 %iv9, 1
  %52 = trunc i64 %51 to i32
  %53 = add nuw nsw i64 %indvars.iv91.ph, %iv9
  %next_expr_index58 = add i32 %52, 1
  %F-159 = getelementptr inbounds double, ptr %y, i64 %53
  %F-160 = load double, ptr %F-159, align 8
  %F-161 = add i32 %52, 50
  %54 = sext i32 %F-161 to i64
  %F-162 = getelementptr inbounds double, ptr %3, i64 %54
  store double %F-160, ptr %F-162, align 8
  %indvars.iv.next92 = add nuw nsw i64 %53, 1
  %55 = trunc i64 %indvars.iv.next92 to i32
  %F-165 = icmp ugt i32 %end3, %55
  br i1 %F-165, label %F-1, label %exit67.loopexit, !llvm.loop !9

exit67.loopexit:                                  ; preds = %F-1
  br label %exit67

exit67:                                           ; preds = %exit67.loopexit, %middle.block115, %exit.thread
  ret void
}

  %6 = atomicrmw add ptr @thread_counter, i32 1 monotonic, align 4
 Active atomic inst not yet handled

Activity

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions