BEATs: Audio Pre-Training with Acoustic Tokenizers
Official PyTorch implementation and pretrained models of BEATs
| Iterations | Tokenizer | Pre-Trained Model | AudioSet Fine-Tuned Model 1 | AudioSet Fine-Tuned Model 2 |
|---|---|---|---|---|
| Iter1 | Random Projection | BEATs_iter1 | Fine-tuned BEATs_iter1 (cpt1) | Fine-tuned BEATs_iter1 (cpt2) |
| Iter2 | Tokenizer_iter2 | BEATs_iter2 | Fine-tuned BEATs_iter2 (cpt1) | Fine-tuned BEATs_iter2 (cpt2) |
| Iter3 | Tokenizer_iter3 | BEATs_iter3 | Fine-tuned BEATs_iter3 (cpt1) | Fine-tuned BEATs_iter3 (cpt2) |
| Iter3+ | Tokenizer_iter3+ (AS20K) | BEATs_iter3+ (AS20K) | Fine-tuned BEATs_iter3+ (AS20K) (cpt1) | Fine-tuned BEATs_iter3+ (AS20K) (cpt2) |
| Iter3+ | Tokenizer_iter3+ (AS2M) | BEATs_iter3+ (AS2M) | Fine-tuned BEATs_iter3+ (AS2M) (cpt1) | Fine-tuned BEATs_iter3+ (AS2M) (cpt2) |
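Each download is a standard PyTorch checkpoint. Before wiring one into the loaders below, it can help to confirm what it contains; a minimal sketch (the path is a placeholder, and the key names follow the loading code in this README):

```python
import torch

# placeholder path; point this at any downloaded BEATs checkpoint
ckpt = torch.load('/path/to/checkpoint.pt', map_location='cpu')
# expect at least 'cfg' and 'model'; AudioSet fine-tuned checkpoints also carry 'label_dict'
print(ckpt.keys())
```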
Load the pre-trained tokenizer and use it to generate discrete acoustic labels for an audio clip:

```python
import torch
from Tokenizers import TokenizersConfig, Tokenizers

# load the pre-trained checkpoint
checkpoint = torch.load('/path/to/tokenizer.pt', map_location='cpu')

cfg = TokenizersConfig(checkpoint['cfg'])
BEATs_tokenizer = Tokenizers(cfg)
BEATs_tokenizer.load_state_dict(checkpoint['model'])
BEATs_tokenizer.eval()

# tokenize the audio and generate the labels
audio_input_16khz = torch.randn(1, 10000)
padding_mask = torch.zeros(1, 10000).bool()

labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=padding_mask)
```
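The random tensor above is only a stand-in; the examples in this README use 16 kHz mono input. A minimal sketch of feeding real audio through the tokenizer, assuming torchaudio is installed (the file path is a placeholder):

```python
import torch
import torchaudio

# placeholder path to any audio file torchaudio can decode
waveform, sample_rate = torchaudio.load('/path/to/audio.wav')

# downmix to mono and resample to 16 kHz if necessary
waveform = waveform.mean(dim=0, keepdim=True)
if sample_rate != 16000:
    waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=16000)

padding_mask = torch.zeros(waveform.shape).bool()
labels = BEATs_tokenizer.extract_labels(waveform, padding_mask=padding_mask)
```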
Load a pre-trained BEATs model and extract audio representations:

```python
import torch
from BEATs import BEATs, BEATsConfig

# load the pre-trained checkpoint
checkpoint = torch.load('/path/to/model.pt', map_location='cpu')

cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()

# extract the audio representation
audio_input_16khz = torch.randn(1, 10000)
padding_mask = torch.zeros(1, 10000).bool()

representation = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]
```
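`extract_features` returns frame-level features. Assuming the output shape is (batch, frames, dim), a simple mean over the time axis yields one fixed-size embedding per clip; this is a sketch, not an official utility:

```python
# mean-pool the frame-level features (assumed shape: batch x frames x dim)
# into a single clip-level embedding per audio clip
clip_embedding = representation.mean(dim=1)
print(clip_embedding.shape)  # (batch, dim)
```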
Load an AudioSet fine-tuned BEATs model and predict the classification probability of each class:

```python
import torch
from BEATs import BEATs, BEATsConfig

# load the fine-tuned checkpoint
checkpoint = torch.load('/path/to/model.pt', map_location='cpu')

cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()

# predict the classification probability of each class
audio_input_16khz = torch.randn(3, 10000)
padding_mask = torch.zeros(3, 10000).bool()

probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]

for i, (top5_label_prob, top5_label_idx) in enumerate(zip(*probs.topk(k=5))):
    top5_label = [checkpoint['label_dict'][label_idx.item()] for label_idx in top5_label_idx]
    print(f'Top 5 predicted labels of audio clip {i} are {top5_label} with probabilities {top5_label_prob}')
```
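AudioSet tagging is multi-label, so a fixed probability threshold can be a natural alternative to top-k; the 0.5 cutoff below is an arbitrary illustration, not a recommended value:

```python
# keep every class whose predicted probability exceeds the (arbitrary) threshold
threshold = 0.5
for i, clip_probs in enumerate(probs):
    predicted = [checkpoint['label_dict'][idx] for idx, p in enumerate(clip_probs.tolist()) if p > threshold]
    print(f'Labels above {threshold} for audio clip {i}: {predicted}')
```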
This project is licensed under the license found in the LICENSE file in the root directory of this source tree. Portions of the source code are based on the FAIRSEQ and VQGAN projects.
Microsoft Open Source Code of Conduct
If you find our work useful in your research, please cite the following paper:
```
@article{Chen2022beats,
  title         = {BEATs: Audio Pre-Training with Acoustic Tokenizers},
  author        = {Sanyuan Chen and Yu Wu and Chengyi Wang and Shujie Liu and Daniel Tompkins and Zhuo Chen and Furu Wei},
  eprint        = {2212.09058},
  archivePrefix = {arXiv},
  year          = {2022}
}
```
For help or issues using BEATs models, please submit a GitHub issue.
For other communications related to BEATs, please contact Yu Wu ([email protected]).