This repository contains a custom PyTorch CUDA kernel for batched windowed matrix multiplication. This is particularly useful for the windowed local attention in sparse attention transformer models such as BigBird and Longformer. Given two matrices $A \in \mathbb{R}^{b \times n \times d}$ and $B \in \mathbb{R}^{b \times d \times n}$ and a window size $w$, window_matmul computes the banded product $C \in \mathbb{R}^{b \times n \times (2w + 1)}$ with $C_{:,i,j} = A_{:,i,:} \cdot B_{:,:,i+j-w}$; in the attention setting, $A$ and $B$ are the query and key matrices and $C$ holds the attention scores of each position with its $w$ left and right neighbors.
To complete the windowed attention operation, the attention matrix $C$ has to be multiplied with the value matrix $V \in \mathbb{R}^{b \times n \times d}$. This is what unwindow_matmul does: it computes $O \in \mathbb{R}^{b \times n \times d}$ with $O_{:,i,:} = \sum_{j=0}^{2w} C_{:,i,j} \, V_{:,i+j-w,:}$.
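For reference, the following is a minimal sketch in plain PyTorch of what the two operations compute. It is not part of the package; the function names naive_window_matmul and naive_unwindow_matmul are illustrative, and positions outside the sequence are assumed to contribute zero:
import torch

def naive_window_matmul(q, k, window_size):
    # q: (batch, n, d), k: (batch, d, n) -> scores: (batch, n, 2 * window_size + 1)
    batch, n, d = q.shape
    scores = q.new_zeros(batch, n, 2 * window_size + 1)
    for i in range(n):
        for offset in range(-window_size, window_size + 1):
            j = i + offset
            if 0 <= j < n:
                # dot product of row i of q with column j of k
                scores[:, i, offset + window_size] = (q[:, i, :] * k[:, :, j]).sum(-1)
    return scores

def naive_unwindow_matmul(a, v, window_size):
    # a: (batch, n, 2 * window_size + 1), v: (batch, n, d) -> out: (batch, n, d)
    batch, n, _ = a.shape
    out = v.new_zeros(batch, n, v.shape[-1])
    for i in range(n):
        for offset in range(-window_size, window_size + 1):
            j = i + offset
            if 0 <= j < n:
                # weight row j of v by the score at this offset
                out[:, i, :] += a[:, i, offset + window_size, None] * v[:, j, :]
    return out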
Be sure to have the CUDA toolkit installed before running pip install. We recommend installing it with conda:
conda install -c nvidia cuda-toolkit
To install the package, run:
pip install git+https://github.com/webis-de/pytorch-window-matmul.git
An example of how to use the kernel:
import torch
import window_matmul
# create some random matrices
batch_size = 2
seq_len = 10
hidden_size = 5
window_size = 2
q = torch.rand(batch_size, seq_len, hidden_size)  # queries
k = torch.rand(batch_size, hidden_size, seq_len)  # keys (already transposed)
v = torch.rand(batch_size, seq_len, hidden_size)  # values
# compute windowed attention scores: (batch_size, seq_len, 2 * window_size + 1)
a = window_matmul.window_matmul(q, k, window_size)
assert a.shape[-1] == 2 * window_size + 1
# multiply the scores with the values: (batch_size, seq_len, hidden_size)
o = window_matmul.unwindow_matmul(a, v, window_size)
NOTE: The CPU version is not optimized and serves only as a reference implementation. Use the optimized CUDA version in practice.
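To run the example above on the GPU, one would move the inputs to a CUDA device first; this sketch assumes the kernel selects the CUDA implementation based on the device of its inputs:
q, k, v = q.cuda(), k.cuda(), v.cuda()
# both calls should then use the optimized CUDA kernel (assumed device-based dispatch)
a = window_matmul.window_matmul(q, k, window_size)
o = window_matmul.unwindow_matmul(a, v, window_size)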