
Commit 848e7ac

pytorchbot and drisspg authored
[SDPA-CUDNN] Make CuDNN Attention Opt in (#138587)
[SDPA-CUDNN] Make CuDNN Attention Opt in (#138522)

# Summary

Currently we have a `cudnn_order` that says that on H100 with a new enough CuDNN backend (we ship a 9.1 version in OSS) we try to run CuDNN attention first. We have already encountered a few bugs with the release of 2.5:

1. #138529
2. huggingface/diffusers#9704
3. #138354

In light of the above, we are going to make the CuDNN backend opt-in by default. This can be done easily with the context manager for choosing backends, i.e.:

```python
from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```

This PR puts the CuDNN backend at the lowest precedence in the backend list, meaning that the Math backend will always be chosen over it unless that backend is disabled (which is done via the context manager).

Cc @atalman

Pull Request resolved: #138522
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/malfet

(cherry picked from commit 9a9a0ab)
Co-authored-by: drisspg <[email protected]>
1 parent 885c823 commit 848e7ac
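For convenience, here is a self-contained version of the opt-in snippet from the commit message; the tensor shapes, dtype, and device below are illustrative assumptions rather than part of the commit.

```python
# Minimal sketch of opting in to the CuDNN SDPA backend via the context
# manager described above. Shapes, dtype, and device are assumptions.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

batch, heads, seq_len, head_dim = 2, 8, 1024, 64
shape = (batch, heads, seq_len, head_dim)
q = torch.randn(shape, device="cuda", dtype=torch.float16)
k = torch.randn(shape, device="cuda", dtype=torch.float16)
v = torch.randn(shape, device="cuda", dtype=torch.float16)

# Default dispatch: CuDNN attention is no longer tried first on H100.
out_default = F.scaled_dot_product_attention(q, k, v)

# Explicit opt-in: restrict dispatch to the CuDNN backend.
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out_cudnn = F.scaled_dot_product_attention(q, k, v)
```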


2 files changed (+8 lines, -9 lines)


aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 3 additions & 8 deletions

```diff
@@ -68,16 +68,11 @@ bool check_prefer_cudnn_attention() {
 std::array<SDPBackend, num_backends> priority_order(sdp_params const& params) {
   constexpr std::array<SDPBackend, num_backends> default_order{
       SDPBackend::flash_attention,
-      SDPBackend::cudnn_attention,
       SDPBackend::efficient_attention,
-      SDPBackend::math};
-  constexpr std::array<SDPBackend, num_backends> cudnn_order{
+      SDPBackend::math,
       SDPBackend::cudnn_attention,
-      SDPBackend::flash_attention,
-      SDPBackend::efficient_attention,
-      SDPBackend::math};
-  static const bool prefer_cudnn = check_prefer_cudnn_attention();
-  return prefer_cudnn ? cudnn_order : default_order;
+  };
+  return default_order;
 }
 
 bool use_tensor_cores(sdp_params const& params, cudaDeviceProp* dprops, bool is_half) {
```
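To see why moving `cudnn_attention` below `math` makes it effectively opt-in only: the dispatcher picks the first backend in the priority list whose constraint checks pass, and the Math fallback always passes. The following is a simplified, hypothetical Python sketch of that first-match-wins idea, not the actual ATen implementation.

```python
# Simplified, hypothetical sketch of first-match-wins backend selection.
# This is NOT the ATen implementation; it only illustrates why a backend
# placed after the math fallback is never chosen automatically.
from enum import Enum, auto

class Backend(Enum):
    FLASH = auto()
    EFFICIENT = auto()
    MATH = auto()
    CUDNN = auto()

# New default priority order from this commit: flash, efficient, math, cudnn.
DEFAULT_ORDER = [Backend.FLASH, Backend.EFFICIENT, Backend.MATH, Backend.CUDNN]

def select_backend(order, allowed, passes_checks):
    """Pick the first backend that is both allowed and passes its checks."""
    for backend in order:
        if backend in allowed and passes_checks(backend):
            return backend
    raise RuntimeError("No viable SDPA backend")

# The math fallback has no hardware constraints, so it always passes its
# checks; cudnn, sitting behind it, can only win when the caller narrows
# the allowed set (conceptually what the sdpa_kernel context manager does).
def checks(backend):
    return backend is Backend.MATH  # pretend only math passes on this machine

assert select_backend(DEFAULT_ORDER, set(Backend), checks) is Backend.MATH
assert select_backend(DEFAULT_ORDER, {Backend.CUDNN}, lambda b: True) is Backend.CUDNN
```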

test/test_transformers.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -2809,8 +2809,12 @@ def test_fused_sdp_choice(self, device, type: str):
         value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
 
+        # TODO we are currently disabling this by default, lets assert that this returns
+        # FlashAttention, we need to change when we make remove opt-in for cudnn
         if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater:
-            self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
+            self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value)
+            with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
+                self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value)
         elif PLATFORM_SUPPORTS_FLASH_ATTENTION:
             self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value)
         elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION:  # e.g., we're on Windows
```
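Mirroring the updated test, the snippet below is one way to check locally which fused backend gets picked. Note that `torch._fused_sdp_choice` is the private helper the test suite uses, the tensor shapes and dtype are illustrative assumptions, and the expected results assume an SM90 (H100) machine where both FlashAttention and CuDNN attention are available.

```python
# Sketch, not authoritative: reproduce the test's check outside the test suite.
# torch._fused_sdp_choice is a private helper; shapes/dtype are assumptions.
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# After this commit the default choice on SM90 should be FlashAttention.
assert torch._fused_sdp_choice(q, k, v) == SDPBackend.FLASH_ATTENTION.value

# CuDNN attention is only selected when it is explicitly opted into.
with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
    assert torch._fused_sdp_choice(q, k, v) == SDPBackend.CUDNN_ATTENTION.value
```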
