|
| 1 | +/*! |
| 2 | + * Copyright (c) 2015 by Contributors |
| 3 | + * \file cuda_utils.h |
| 4 | + * \brief CUDA debugging utilities. |
| 5 | + */ |
| 6 | +#ifndef MXNET_COMMON_CUDA_UTILS_H_ |
| 7 | +#define MXNET_COMMON_CUDA_UTILS_H_ |
| 8 | + |
| 9 | +#include <dmlc/logging.h> |
| 10 | +#include <mshadow/base.h> |
| 11 | + |
/*!
 * \brief Stand-ins that let CLion parse CUDA sources (*.cu, *.cuh).
 *
 * Everything below is compiled only when __JETBRAINS_IDE__ is defined.
 * The definitions exist purely so the IDE can resolve CUDA keywords,
 * intrinsics and built-in variables; they carry no real semantics.
 */
#ifdef __JETBRAINS_IDE__
#define __CUDACC__ 1
// CUDA function/variable qualifiers expand to nothing.
#define __host__
#define __device__
#define __global__
#define __forceinline__
#define __shared__
// Empty substitutes for the synchronization intrinsics.
inline void __syncthreads() {}
inline void __threadfence_block() {}
// Placeholder only: echoes its argument, it does not count leading zeros.
template <typename T>
inline T __clz(const T val) {
  return val;
}
// Three-int aggregate used as the type of the fake built-in index variables.
struct __cuda_fake_struct {
  int x;
  int y;
  int z;
};
extern __cuda_fake_struct blockDim;
extern __cuda_fake_struct threadIdx;
extern __cuda_fake_struct blockIdx;
#endif  // __JETBRAINS_IDE__
| 28 | + |
| 29 | +#if MXNET_USE_CUDA |
| 30 | + |
| 31 | +#include <cuda_runtime.h> |
| 32 | +#include <cublas_v2.h> |
| 33 | +#include <curand.h> |
| 34 | + |
| 35 | +namespace mxnet { |
| 36 | +namespace common { |
| 37 | +/*! \brief common utils for cuda */ |
| 38 | +namespace cuda { |
| 39 | +/*! |
| 40 | + * \brief Get string representation of cuBLAS errors. |
| 41 | + * \param error The error. |
| 42 | + * \return String representation. |
| 43 | + */ |
| 44 | +inline const char* CublasGetErrorString(cublasStatus_t error) { |
| 45 | + switch (error) { |
| 46 | + case CUBLAS_STATUS_SUCCESS: |
| 47 | + return "CUBLAS_STATUS_SUCCESS"; |
| 48 | + case CUBLAS_STATUS_NOT_INITIALIZED: |
| 49 | + return "CUBLAS_STATUS_NOT_INITIALIZED"; |
| 50 | + case CUBLAS_STATUS_ALLOC_FAILED: |
| 51 | + return "CUBLAS_STATUS_ALLOC_FAILED"; |
| 52 | + case CUBLAS_STATUS_INVALID_VALUE: |
| 53 | + return "CUBLAS_STATUS_INVALID_VALUE"; |
| 54 | + case CUBLAS_STATUS_ARCH_MISMATCH: |
| 55 | + return "CUBLAS_STATUS_ARCH_MISMATCH"; |
| 56 | + case CUBLAS_STATUS_MAPPING_ERROR: |
| 57 | + return "CUBLAS_STATUS_MAPPING_ERROR"; |
| 58 | + case CUBLAS_STATUS_EXECUTION_FAILED: |
| 59 | + return "CUBLAS_STATUS_EXECUTION_FAILED"; |
| 60 | + case CUBLAS_STATUS_INTERNAL_ERROR: |
| 61 | + return "CUBLAS_STATUS_INTERNAL_ERROR"; |
| 62 | + case CUBLAS_STATUS_NOT_SUPPORTED: |
| 63 | + return "CUBLAS_STATUS_NOT_SUPPORTED"; |
| 64 | + default: |
| 65 | + break; |
| 66 | + } |
| 67 | + return "Unknown cuBLAS status"; |
| 68 | +} |
| 69 | + |
| 70 | +/*! |
| 71 | + * \brief Get string representation of cuRAND errors. |
| 72 | + * \param status The status. |
| 73 | + * \return String representation. |
| 74 | + */ |
| 75 | +inline const char* CurandGetErrorString(curandStatus_t status) { |
| 76 | + switch (status) { |
| 77 | + case CURAND_STATUS_SUCCESS: |
| 78 | + return "CURAND_STATUS_SUCCESS"; |
| 79 | + case CURAND_STATUS_VERSION_MISMATCH: |
| 80 | + return "CURAND_STATUS_VERSION_MISMATCH"; |
| 81 | + case CURAND_STATUS_NOT_INITIALIZED: |
| 82 | + return "CURAND_STATUS_NOT_INITIALIZED"; |
| 83 | + case CURAND_STATUS_ALLOCATION_FAILED: |
| 84 | + return "CURAND_STATUS_ALLOCATION_FAILED"; |
| 85 | + case CURAND_STATUS_TYPE_ERROR: |
| 86 | + return "CURAND_STATUS_TYPE_ERROR"; |
| 87 | + case CURAND_STATUS_OUT_OF_RANGE: |
| 88 | + return "CURAND_STATUS_OUT_OF_RANGE"; |
| 89 | + case CURAND_STATUS_LENGTH_NOT_MULTIPLE: |
| 90 | + return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; |
| 91 | + case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: |
| 92 | + return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; |
| 93 | + case CURAND_STATUS_LAUNCH_FAILURE: |
| 94 | + return "CURAND_STATUS_LAUNCH_FAILURE"; |
| 95 | + case CURAND_STATUS_PREEXISTING_FAILURE: |
| 96 | + return "CURAND_STATUS_PREEXISTING_FAILURE"; |
| 97 | + case CURAND_STATUS_INITIALIZATION_FAILED: |
| 98 | + return "CURAND_STATUS_INITIALIZATION_FAILED"; |
| 99 | + case CURAND_STATUS_ARCH_MISMATCH: |
| 100 | + return "CURAND_STATUS_ARCH_MISMATCH"; |
| 101 | + case CURAND_STATUS_INTERNAL_ERROR: |
| 102 | + return "CURAND_STATUS_INTERNAL_ERROR"; |
| 103 | + } |
| 104 | + return "Unknown cuRAND status"; |
| 105 | +} |
| 106 | + |
| 107 | +} // namespace cuda |
| 108 | +} // namespace common |
| 109 | +} // namespace mxnet |
| 110 | + |
/*!
 * \brief Verify that no CUDA runtime error is pending.
 * \param msg Text streamed ahead of the CUDA error description if an
 *            error occurred.
 *
 * Fetches the last error with cudaGetLastError() (which also resets it)
 * and fails a CHECK_EQ when the result is not cudaSuccess.
 *
 * Note: the expansion is a braced block, not a single statement.
 */
#define CHECK_CUDA_ERROR(msg)                                                \
  {                                                                          \
    const cudaError_t e = cudaGetLastError();                                \
    CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
  }
| 120 | + |
/*!
 * \brief Guarded CUDA runtime call.
 * \param func Expression yielding a cudaError_t.
 *
 * Evaluates the expression once and fails a CHECK unless the result is
 * cudaSuccess or cudaErrorCudartUnloading; the failure message carries
 * the text from cudaGetErrorString().
 *
 * Note: the expansion is a braced block, not a single statement.
 */
#define CUDA_CALL(func)                                            \
  {                                                                \
    const cudaError_t e = (func);                                  \
    CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading)       \
        << "CUDA: " << cudaGetErrorString(e);                      \
  }
| 133 | + |
/*!
 * \brief Protected cuBLAS call.
 * \param func Expression to call; must yield a cublasStatus_t.
 *
 * Evaluates the expression once and fails a CHECK_EQ unless the result is
 * CUBLAS_STATUS_SUCCESS. The failure message includes the symbolic status
 * name.
 *
 * The helper is referenced by its fully qualified name so the macro
 * expands correctly from any namespace, not only from inside ::mxnet.
 *
 * Note: the expansion is a braced block, not a single statement.
 */
#define CUBLAS_CALL(func)                                                 \
  {                                                                       \
    cublasStatus_t e = (func);                                            \
    CHECK_EQ(e, CUBLAS_STATUS_SUCCESS)                                    \
        << "cuBLAS: " << ::mxnet::common::cuda::CublasGetErrorString(e);  \
  }
| 146 | + |
/*!
 * \brief Protected cuRAND call.
 * \param func Expression to call; must yield a curandStatus_t.
 *
 * Evaluates the expression once and fails a CHECK_EQ unless the result is
 * CURAND_STATUS_SUCCESS. The failure message includes the symbolic status
 * name.
 *
 * The helper is referenced by its fully qualified name so the macro
 * expands correctly from any namespace, not only from inside ::mxnet.
 *
 * Note: the expansion is a braced block, not a single statement.
 */
#define CURAND_CALL(func)                                                 \
  {                                                                       \
    curandStatus_t e = (func);                                            \
    CHECK_EQ(e, CURAND_STATUS_SUCCESS)                                    \
        << "cuRAND: " << ::mxnet::common::cuda::CurandGetErrorString(e);  \
  }
| 159 | + |
| 160 | +#endif // MXNET_USE_CUDA |
| 161 | + |
| 162 | +#if MXNET_USE_CUDNN |
| 163 | + |
| 164 | +#include <cudnn.h> |
| 165 | + |
/*!
 * \brief Guarded cuDNN call.
 * \param func Expression yielding a cudnnStatus_t.
 *
 * Evaluates the expression once and fails a CHECK_EQ unless the result is
 * CUDNN_STATUS_SUCCESS; the failure message carries the text from
 * cudnnGetErrorString().
 *
 * Note: the expansion is a braced block, not a single statement.
 */
#define CUDNN_CALL(func)                          \
  {                                               \
    const cudnnStatus_t e = (func);               \
    CHECK_EQ(e, CUDNN_STATUS_SUCCESS)             \
        << "cuDNN: " << cudnnGetErrorString(e);   \
  }
| 171 | + |
| 172 | +#endif // MXNET_USE_CUDNN |
| 173 | + |
// Provide atomicAdd for doubles on device architectures below compute
// capability 6.0 (the guard below restricts this to __CUDA_ARCH__ < 600).
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
/*!
 * \brief atomicAdd for double implemented with a 64-bit atomicCAS loop
 *        (adapted from the CUDA Programming Guide).
 * \param address Location to update.
 * \param val Value to add to *address.
 * \return The value that was stored at address before the addition, as the
 *         built-in atomicAdd overloads do. Callers that ignore the result
 *         are unaffected.
 */
static inline __device__ double atomicAdd(double *address, double val) {
  unsigned long long* address_as_ull =                  // NOLINT(*)
    reinterpret_cast<unsigned long long*>(address);     // NOLINT(*)
  unsigned long long old = *address_as_ull;             // NOLINT(*)
  unsigned long long assumed;                           // NOLINT(*)

  do {
    assumed = old;
    // Try to swap in (assumed + val); atomicCAS returns what was actually
    // stored, which differs from `assumed` if another thread got in first.
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val +
                                         __longlong_as_double(assumed)));

    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);

  // On exit `old` holds the bit pattern that was replaced by the sum.
  return __longlong_as_double(old);
}
#endif
| 193 | + |
// Overload atomicAdd for half precision
// Taken from:
// https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
#if defined(__CUDA_ARCH__)
/*!
 * \brief atomicAdd for mshadow half_t, emulated with a 32-bit atomicCAS on
 *        the aligned word that contains the 16-bit value.
 * \param address Location of the half value to update.
 * \param val Value to add to *address.
 *
 * NOTE(review): assumes half_t::half_ holds the raw 16-bit pattern of the
 * value - confirm against the mshadow definition.
 */
static inline __device__ void atomicAdd(mshadow::half::half_t *address,
                                        mshadow::half::half_t val) {
  // True when bit 1 of the address is set, i.e. the half lives in the upper
  // 16 bits of its enclosing 4-byte-aligned word. Loop-invariant, so it is
  // computed once.
  const bool in_upper_half = (reinterpret_cast<size_t>(address) & 2) != 0;
  // Step back to the start of the enclosing 32-bit word.
  unsigned int *address_as_ui =
      reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) -
                                       (in_upper_half ? 2 : 0));
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  do {
    assumed = old;
    mshadow::half::half_t hsum;
    // Extract the 16 bits that belong to *address and add val to them.
    hsum.half_ = in_upper_half ? (old >> 16) : (old & 0xffff);
    hsum += val;
    // Widen to unsigned before shifting: shifting the int-promoted value
    // left by 16 could move a set bit into the sign bit, which is
    // undefined behaviour for signed integers.
    const unsigned int sum_bits = static_cast<unsigned int>(hsum.half_);
    // Splice the new 16 bits back in, leaving the neighbouring half intact.
    old = in_upper_half ? (old & 0xffff) | (sum_bits << 16)
                        : (old & 0xffff0000) | sum_bits;
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);  // retry if the word changed under us
}
#endif
| 219 | + |
| 220 | +#endif // MXNET_COMMON_CUDA_UTILS_H_ |
0 commit comments