@Chillee
Created August 14, 2024 17:38
FlexAttention examples
import torch.nn as nn
import copy
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, or_masks, create_mask
from triton.testing import do_bench
from functools import partial

torch.set_default_device('cuda')

# Batch of 4, 16 heads, sequence length 4096, head dim 64.
B = 4
q, k, v = [torch.randn(B, 16, 4096, 64, requires_grad=True, dtype=torch.float16) for _ in range(3)]

# Standard causal mask: a query index may attend to any kv index at or before it.
def causal_mask(b, h, q, kv):
    return q >= kv

# Bidirectional (full) attention over the per-batch prefix.
def prefix_full(b, h, q, kv, prefix_lengths):
    return kv <= prefix_lengths[b]

# PrefixLM = full attention over the prefix OR causal attention elsewhere.
short_prefixes = torch.randint(512, 1024, (B,), dtype=torch.int)
short_prefix_mask = create_block_mask(or_masks(causal_mask, partial(prefix_full, prefix_lengths=short_prefixes)), B, None, 4096, 4096)
long_prefixes = torch.randint(2048, 2048 + 1024, (B,), dtype=torch.int)
long_prefix_mask = create_block_mask(or_masks(causal_mask, partial(prefix_full, prefix_lengths=long_prefixes)), B, None, 4096, 4096)
print("short prefixes: ", short_prefix_mask)

flex_attention = torch.compile(flex_attention)

# FlexAttention with a BlockMask exploits block sparsity (fully masked blocks are skipped).
print("short prefixLM: ", do_bench(lambda: flex_attention(q, k, v, block_mask=short_prefix_mask).sum().backward()))

# Baseline: SDPA with a dense boolean attention mask (no block sparsity).
mask = create_mask(or_masks(causal_mask, partial(prefix_full, prefix_lengths=short_prefixes)), B, 1, 4096, 4096)
print("xformers/sdpa with mask: ", do_bench(lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask).sum().backward()))

# Baseline: FlashAttention via SDPA with no mask (full attention).
print("FA (full): ", do_bench(lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v).sum().backward()))