Complete implementation of the ViT (Vision Transformer) model using Triton kernels. The weights are imported from HuggingFace (HF), and the implementation has been benchmarked against the official HF implementation; the results are in the benchmarks folder. You can use this ViT implementation as an educational resource or in your own pipeline. The implementation is fully functional, but it only supports forward passes for now.
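To give a flavor of the weight import, here is a heavily simplified sketch of copying weights from the HF model into a custom module. The key renames below are illustrative only, not the repo's actual mapping (that lives in `vit/load_weights.py`):

```python
# Illustrative sketch: copying HuggingFace ViT weights into a custom module.
# The target key names below are hypothetical; see vit/load_weights.py for the real mapping.
import torch
from transformers import ViTModel

hf_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
hf_state = hf_model.state_dict()

# Example of a key rename: HF name -> (hypothetical) custom name
key_map = {
    "embeddings.cls_token": "cls_token",
    "embeddings.position_embeddings": "pos_embed",
}

custom_state = {}
for hf_key, tensor in hf_state.items():
    custom_key = key_map.get(hf_key, hf_key)
    custom_state[custom_key] = tensor.clone()

# custom_model.load_state_dict(custom_state)  # custom_model is your own nn.Module
```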
Some accompanying posts can help you get started with GPU programming:
[Current]
This repo has 3 goals:
- Learn GPU programming by writing Triton kernels
- Learn to use Triton kernels in a model by implementing an architecture in PyTorch that calls custom Triton kernels (see the sketch after this list)
- Learn to load model weights into a custom implementation from an external source like HuggingFace
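As a minimal illustration of the second goal, the sketch below shows the usual pattern for wrapping a Triton kernel so PyTorch code can call it like any other op. The element-wise add is a toy example, not one of the repo's kernels verbatim:

```python
# Minimal sketch: a Triton kernel wrapped as a PyTorch-callable function.
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements               # guard against out-of-bounds accesses
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def triton_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out

# Inside a PyTorch model you would call triton_add(a, b) instead of a + b.
```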
[Future]
This repo aims to become a standalone, pip-installable package for running ViT in the most optimized way. To get there, the following optimizations are needed:
- Faster Conv1D implementation - This should help beat the HF implementation at all batch sizes
- Add Flash Attention - This should significantly reduce runtime
- Fix all tensor sizes for a given batch size
- Use CUDA graphs to optimize kernel dispatch time (a sketch follows below)
- Support for other ViT flavors
These changes should make this a really fast and optimized ViT implementation that can be used as an image encoder in multimodal models during LLM inference or training. Training/fine-tuning of the ViT itself is currently out of scope for this repo.
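To illustrate the last two optimizations: once all tensor shapes are fixed for a given batch size, the whole forward pass can be captured into a CUDA graph and replayed, removing per-kernel dispatch overhead. This is a hedged sketch of the standard PyTorch capture pattern; the stand-in model and shapes are illustrative, not the repo's code:

```python
# Sketch: capture a fixed-shape forward pass into a CUDA graph and replay it.
import torch

model = torch.nn.Conv2d(3, 16, 3).cuda()  # stand-in for the ViT forward pass
static_input = torch.zeros(8, 3, 224, 224, device="cuda")  # fixed batch size 8

# Warm up on a side stream before capture (required for CUDA graph capture)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        _ = model(static_input)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# At inference time: copy new data into the static buffer and replay the graph.
static_input.copy_(torch.randn(8, 3, 224, 224, device="cuda"))
graph.replay()
result = static_output.clone()
```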
If you'd like to test this implementation on your machine, all you need to do is:
git clone https://github.com/cmeraki/vit.triton.git
cd vit.triton
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt # Requirements are suited for an NVIDIA GPU and Linux setup
python -m vit.vit # This runs the benchmark on both the HF implementation of ViT and the custom implementation
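Beyond the bundled benchmark, a quick numerical sanity check against the HF reference can be done along these lines. `MyViT` and `load_hf_weights` are hypothetical stand-ins for whatever `vit/vit.py` and `vit/load_weights.py` actually expose:

```python
# Hedged sketch: comparing a custom ViT forward pass against the HF reference.
import torch
from transformers import ViTModel

hf_model = ViTModel.from_pretrained("google/vit-base-patch16-224").cuda().eval()
pixel_values = torch.randn(1, 3, 224, 224, device="cuda")

with torch.no_grad():
    hf_out = hf_model(pixel_values).last_hidden_state
    # custom_out = MyViT(load_hf_weights(...))(pixel_values)          # hypothetical API
    # torch.testing.assert_close(custom_out, hf_out, rtol=1e-3, atol=1e-3)
```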
- vit/
  - kernels/        # All Triton kernels reside here
  - load_weights.py # Functions for loading weights from HF
  - utils.py        # Utils
  - vit.py          # Architecture written in torch, but calling Triton kernels
- benchmarks/       # Benchmark results
- examples/         # Small examples used in posts
TBA
The benchmarks were run on an NVIDIA RTX 3080 Ti Founders Edition.
The results of the benchmarks comparing the custom Triton kernels with the corresponding PyTorch operations are as follows (higher is better):
- add
- layernorm
- matmul
- matmul3
- softmax
- conv2d
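For reference, kernel benchmarks like these are typically produced with `triton.testing.perf_report`. The sketch below shows the general shape for the add kernel; the repo's actual benchmark configs may differ:

```python
# Illustrative sketch of a Triton vs. PyTorch kernel benchmark using triton.testing.
import torch
import triton
import triton.testing


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],
        x_vals=[2**i for i in range(12, 24, 2)],
        line_arg="provider",
        line_vals=["triton", "torch"],
        line_names=["Triton", "Torch"],
        ylabel="GB/s",
        plot_name="add",
        args={},
    )
)
def benchmark(N, provider):
    x = torch.randn(N, device="cuda")
    y = torch.randn(N, device="cuda")
    if provider == "torch":
        ms = triton.testing.do_bench(lambda: x + y)
    else:
        # triton_add is the wrapper from the Triton sketch earlier in this README
        ms = triton.testing.do_bench(lambda: triton_add(x, y))
    # 2 reads + 1 write of N elements, converted to GB/s
    return 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)


benchmark.run(print_data=True)
```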
The results of the benchmarks comparing the custom ViT implementation with HF are as follows (lower is better):
Currently, the HuggingFace implementation is faster, but I am working on making this implementation beat it. Stay tuned!
In case you have any questions or suggestions, feel free to raise an issue!