Add lds and sts inline ptx instructions to force vector instruction generation (#273)

Adds inline PTX assembly for the lds & sts instructions for float, float2, float4, double, and double2.
This ensures the compiler doesn't mistakenly generate non-vectorized instructions where we need the vectorized versions.
It also ensures we always generate shared-memory-specific (non-generic) ld/st instructions, preventing the compiler from emitting generic ld/st.
These functions now require the given shmem pointer to be aligned to the vector width; for example, for the float4 lds/sts the shmem pointer must be 16-byte aligned, otherwise the operation may silently fail or cause a runtime error.
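
For context, a minimal usage sketch (not part of this commit; the kernel name, array sizes, and launch configuration are illustrative assumptions) showing how the vectorized overloads are meant to be called, with the shared-memory buffer aligned to the 16-byte requirement of the float4 path:

// Illustrative only: a hypothetical kernel staging 4 floats per thread through
// shared memory via the vectorized overloads. Assumes blockDim.x == 64 so that
// 4 * threadIdx.x stays within smem.
#include <raft/common/device_loads_stores.cuh>

__global__ void copy_via_smem(const float* in, float* out) {
  // 16-byte alignment so the float4 (st/ld.shared.v4.f32) overloads are safe.
  __shared__ __align__(16) float smem[256];
  const int i = 4 * threadIdx.x;

  float r[4] = {in[i], in[i + 1], in[i + 2], in[i + 3]};
  raft::sts(smem + i, r);  // one vectorized shared-memory store
  __syncthreads();

  float w[4];
  raft::lds(w, smem + i);  // one vectorized shared-memory load
  out[i]     = w[0];
  out[i + 1] = w[1];
  out[i + 2] = w[2];
  out[i + 3] = w[3];
}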

Authors:
  - Mahesh Doijade (https://github.com/mdoijade)

Approvers:
  - Thejaswi. N. S (https://github.com/teju85)

URL: #273
mdoijade authored Jun 21, 2021
1 parent 806b7fa commit b266d54
Showing 1 changed file with 68 additions and 33 deletions.
cpp/include/raft/common/device_loads_stores.cuh (68 additions, 33 deletions)
@@ -24,60 +24,95 @@ namespace raft {
  * @defgroup SmemStores Shared memory store operations
  * @{
  * @brief Stores to shared memory (both vectorized and non-vectorized forms)
- * @param[out] addr shared memory address
+ *        requires the given shmem pointer to be aligned by the vector
+ *        length, like for float4 lds/sts shmem pointer should be aligned
+ *        by 16 bytes else it might silently fail or can also give
+ *        runtime error.
+ * @param[out] addr shared memory address (should be aligned to vector size)
  * @param[in] x data to be stored at this address
  */
-DI void sts(float* addr, const float& x) { *addr = x; }
-DI void sts(float* addr, const float (&x)[1]) { *addr = x[0]; }
+DI void sts(float* addr, const float& x) {
+  auto s1 = __cvta_generic_to_shared(reinterpret_cast<float*>(addr));
+  asm volatile("st.shared.f32 [%0], {%1};" : : "l"(s1), "f"(x));
+}
+DI void sts(float* addr, const float (&x)[1]) {
+  auto s1 = __cvta_generic_to_shared(reinterpret_cast<float*>(addr));
+  asm volatile("st.shared.f32 [%0], {%1};" : : "l"(s1), "f"(x[0]));
+}
 DI void sts(float* addr, const float (&x)[2]) {
-  float2 v2 = make_float2(x[0], x[1]);
-  auto* s2 = reinterpret_cast<float2*>(addr);
-  *s2 = v2;
+  auto s2 = __cvta_generic_to_shared(reinterpret_cast<float2*>(addr));
+  asm volatile("st.shared.v2.f32 [%0], {%1, %2};"
+               :
+               : "l"(s2), "f"(x[0]), "f"(x[1]));
 }
 DI void sts(float* addr, const float (&x)[4]) {
-  float4 v4 = make_float4(x[0], x[1], x[2], x[3]);
-  auto* s4 = reinterpret_cast<float4*>(addr);
-  *s4 = v4;
+  auto s4 = __cvta_generic_to_shared(reinterpret_cast<float4*>(addr));
+  asm volatile("st.shared.v4.f32 [%0], {%1, %2, %3, %4};"
+               :
+               : "l"(s4), "f"(x[0]), "f"(x[1]), "f"(x[2]), "f"(x[3]));
 }
 
-DI void sts(double* addr, const double& x) { *addr = x; }
-DI void sts(double* addr, const double (&x)[1]) { *addr = x[0]; }
+DI void sts(double* addr, const double& x) {
+  auto s1 = __cvta_generic_to_shared(reinterpret_cast<double*>(addr));
+  asm volatile("st.shared.f64 [%0], {%1};" : : "l"(s1), "d"(x));
+}
+DI void sts(double* addr, const double (&x)[1]) {
+  auto s1 = __cvta_generic_to_shared(reinterpret_cast<double*>(addr));
+  asm volatile("st.shared.f64 [%0], {%1};" : : "l"(s1), "d"(x[0]));
+}
 DI void sts(double* addr, const double (&x)[2]) {
-  double2 v2 = make_double2(x[0], x[1]);
-  auto* s2 = reinterpret_cast<double2*>(addr);
-  *s2 = v2;
+  auto s2 = __cvta_generic_to_shared(reinterpret_cast<double2*>(addr));
+  asm volatile("st.shared.v2.f64 [%0], {%1, %2};"
+               :
+               : "l"(s2), "d"(x[0]), "d"(x[1]));
 }
 /** @} */
 
 /**
  * @defgroup SmemLoads Shared memory load operations
  * @{
  * @brief Loads from shared memory (both vectorized and non-vectorized forms)
+ *        requires the given shmem pointer to be aligned by the vector
+ *        length, like for float4 lds/sts shmem pointer should be aligned
+ *        by 16 bytes else it might silently fail or can also give
+ *        runtime error.
  * @param[out] x the data to be loaded
  * @param[in] addr shared memory address from where to load
+ *                 (should be aligned to vector size)
  */
-DI void lds(float& x, float* addr) { x = *addr; }
-DI void lds(float (&x)[1], float* addr) { x[0] = *addr; }
+DI void lds(float& x, float* addr) {
+  auto s1 = __cvta_generic_to_shared(reinterpret_cast<float*>(addr));
+  asm volatile("ld.shared.f32 {%0}, [%1];" : "=f"(x) : "l"(s1));
+}
+DI void lds(float (&x)[1], float* addr) {
+  auto s1 = __cvta_generic_to_shared(reinterpret_cast<float*>(addr));
+  asm volatile("ld.shared.f32 {%0}, [%1];" : "=f"(x[0]) : "l"(s1));
+}
 DI void lds(float (&x)[2], float* addr) {
-  auto* s2 = reinterpret_cast<float2*>(addr);
-  auto v2 = *s2;
-  x[0] = v2.x;
-  x[1] = v2.y;
+  auto s2 = __cvta_generic_to_shared(reinterpret_cast<float2*>(addr));
+  asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];"
+               : "=f"(x[0]), "=f"(x[1])
+               : "l"(s2));
 }
 DI void lds(float (&x)[4], float* addr) {
-  auto* s4 = reinterpret_cast<float4*>(addr);
-  auto v4 = *s4;
-  x[0] = v4.x;
-  x[1] = v4.y;
-  x[2] = v4.z;
-  x[3] = v4.w;
+  auto s4 = __cvta_generic_to_shared(reinterpret_cast<float4*>(addr));
+  asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];"
+               : "=f"(x[0]), "=f"(x[1]), "=f"(x[2]), "=f"(x[3])
+               : "l"(s4));
 }
-DI void lds(double& x, double* addr) { x = *addr; }
-DI void lds(double (&x)[1], double* addr) { x[0] = *addr; }
+DI void lds(double& x, double* addr) {
+  auto s1 = __cvta_generic_to_shared(reinterpret_cast<double*>(addr));
+  asm volatile("ld.shared.f64 {%0}, [%1];" : "=d"(x) : "l"(s1));
+}
+DI void lds(double (&x)[1], double* addr) {
+  auto s1 = __cvta_generic_to_shared(reinterpret_cast<double*>(addr));
+  asm volatile("ld.shared.f64 {%0}, [%1];" : "=d"(x[0]) : "l"(s1));
+}
 DI void lds(double (&x)[2], double* addr) {
-  auto* s2 = reinterpret_cast<double2*>(addr);
-  auto v2 = *s2;
-  x[0] = v2.x;
-  x[1] = v2.y;
+  auto s2 = __cvta_generic_to_shared(reinterpret_cast<double2*>(addr));
+  asm volatile("ld.shared.v2.f64 {%0, %1}, [%2];"
+               : "=d"(x[0]), "=d"(x[1])
+               : "l"(s2));
 }
 /** @} */
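
A note on the alignment requirement spelled out in the doc comments above: each vectorized overload assumes its shared-memory pointer is aligned to the full vector width, and the generated PTX may misbehave otherwise. A minimal sketch (hypothetical offsets, not part of the commit) of which pointers are valid for which overload:

// Illustrative device function; offsets are examples only.
__device__ void alignment_examples() {
  __shared__ __align__(16) float smem[128];  // base aligned to 16 bytes
  float v2[2], v4[4];

  raft::lds(v4, smem);      // OK: 16-byte aligned, matches ld.shared.v4.f32
  raft::lds(v4, smem + 4);  // OK: +16 bytes keeps 16-byte alignment
  raft::lds(v2, smem + 2);  // OK: the float2 overload only needs 8-byte alignment
  // raft::lds(v4, smem + 2);  // NOT OK: only 8-byte aligned; violates the
  //                           // 16-byte requirement of the float4 overload.
}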
