Skip to content

Commit

Permalink
atomics: Add warp-aggregated atomic increment
Browse files Browse the repository at this point in the history
Faster atomic counter increment using warp-aggregated atomics. Useful
for filtering.

Adapted from:
https://developer.nvidia.com/blog/cuda-pro-tip-optimized-filtering-warp-aggregated-atomics/
  • Loading branch information
ahendriksen committed Jul 8, 2022
1 parent 2b27bad commit 4a76a73
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions cpp/include/raft/device_atomics.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
*/

#include <type_traits>
#include <cooperative_groups.h>

namespace raft {

Expand Down Expand Up @@ -636,3 +637,32 @@ __forceinline__ __device__ T atomicXor(T* address, T val)
{
return raft::genericAtomicOperation(address, val, raft::device_atomics::detail::DeviceXor{});
}

/**
* @brief: Warp aggregated atomic increment
*
* increments an atomic counter using all active threads in a warp. The return
* value is the original value of the counter plus the rank of the calling
* thread.
*
* The use of atomicIncWarp is a performance optimization. It can reduce the
* amount of atomic memory traffic by a factor of 32.
*
* Adapted from:
* https://developer.nvidia.com/blog/cuda-pro-tip-optimized-filtering-warp-aggregated-atomics/
*
* @tparam T An integral type
* @param[in,out] ctr The address of old value
*
* @return The old value of the counter plus the rank of the calling thread.
*/
template <typename T = unsigned int,
typename std::enable_if_t<std::is_integral<T>::value, T>* = nullptr>
__device__ T atomicIncWarp(T* ctr)
{
namespace cg = cooperative_groups;
auto g = cg::coalesced_threads();
T warp_res;
if (g.thread_rank() == 0) { warp_res = atomicAdd(ctr, static_cast<T>(g.size())); }
return g.shfl(warp_res, 0) + g.thread_rank();
}

0 comments on commit 4a76a73

Please sign in to comment.