BEATs

BEATs: Audio Pre-Training with Acoustic Tokenizers

Official PyTorch implementation and pretrained models of BEATs

Pre-Trained and Fine-Tuned Tokenizers and Models

| Iterations | Tokenizer | Pre-Trained Model | AudioSet Fine-Tuned Model 1 | AudioSet Fine-Tuned Model 2 |
| --- | --- | --- | --- | --- |
| Iter1 | Random Projection | BEATs_iter1 | Fine-tuned BEATs_iter1 (cpt1) | Fine-tuned BEATs_iter1 (cpt2) |
| Iter2 | Tokenizer_iter2 | BEATs_iter2 | Fine-tuned BEATs_iter2 (cpt1) | Fine-tuned BEATs_iter2 (cpt2) |
| Iter3 | Tokenizer_iter3 | BEATs_iter3 | Fine-tuned BEATs_iter3 (cpt1) | Fine-tuned BEATs_iter3 (cpt2) |
| Iter3+ | Tokenizer_iter3+ (AS20K) | BEATs_iter3+ (AS20K) | Fine-tuned BEATs_iter3+ (AS20K) (cpt1) | Fine-tuned BEATs_iter3+ (AS20K) (cpt2) |
| Iter3+ | Tokenizer_iter3+ (AS2M) | BEATs_iter3+ (AS2M) | Fine-tuned BEATs_iter3+ (AS2M) (cpt1) | Fine-tuned BEATs_iter3+ (AS2M) (cpt2) |

Load Tokenizers

import torch
from Tokenizers import TokenizersConfig, Tokenizers

# load the pre-trained checkpoints
checkpoint = torch.load('/path/to/tokenizer.pt')

cfg = TokenizersConfig(checkpoint['cfg'])
BEATs_tokenizer = Tokenizers(cfg)
BEATs_tokenizer.load_state_dict(checkpoint['model'])
BEATs_tokenizer.eval()

# tokenize the audio and generate the labels
audio_input_16khz = torch.randn(1, 10000)
padding_mask = torch.zeros(1, 10000).bool()

labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=padding_mask)
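
The tokenizer maps the waveform to a sequence of discrete acoustic token ids, which serve as the prediction targets during pre-training. A quick sanity check on the output (a sketch, assuming extract_labels returns a tensor of token ids indexed as [batch, num_frames]):

# inspect the discrete labels (assumed shape: [batch, num_frames])
print(labels.shape)
# first ten token ids of the first clip
print(labels[0][:10])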

Load Pre-Trained Models

import torch
from BEATs import BEATs, BEATsConfig

# load the pre-trained checkpoints
checkpoint = torch.load('/path/to/model.pt')

cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()

# extract the audio representation
audio_input_16khz = torch.randn(1, 10000)
padding_mask = torch.zeros(1, 10000).bool()

representation = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]
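
The random tensor above is only a stand-in; in practice, load a real recording and convert it to 16 kHz mono, the sample rate the model expects (as the variable name audio_input_16khz suggests). A minimal sketch using torchaudio, with a placeholder file path:

import torch
import torchaudio

# load a real recording (placeholder path)
waveform, sample_rate = torchaudio.load('/path/to/audio.wav')

# downmix to mono and resample to 16 kHz if needed
waveform = waveform.mean(dim=0, keepdim=True)
if sample_rate != 16000:
    waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)

# all-False mask: no padded positions in a single clip
padding_mask = torch.zeros(waveform.shape).bool()
representation = BEATs_model.extract_features(waveform, padding_mask=padding_mask)[0]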

Load Fine-Tuned Models

import torch
from BEATs import BEATs, BEATsConfig

# load the fine-tuned checkpoints
checkpoint = torch.load('/path/to/model.pt')

cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()

# predict the classification probability of each class
audio_input_16khz = torch.randn(3, 10000)
padding_mask = torch.zeros(3, 10000).bool()

probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]

for i, (top5_label_prob, top5_label_idx) in enumerate(zip(*probs.topk(k=5))):
    top5_label = [checkpoint['label_dict'][label_idx.item()] for label_idx in top5_label_idx]
    print(f'Top 5 predicted labels of audio {i} are {top5_label} with probabilities {top5_label_prob}')
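
In the examples above the mask is all False, meaning no position is padding. When clips in a batch have different lengths, right-pad them to a common length and set the mask to True over the padded tail, following the convention in these examples that False marks real audio. A sketch with two clips of unequal length (lengths are illustrative):

# two mono clips of different lengths (illustrative)
clip_a = torch.randn(8000)
clip_b = torch.randn(10000)

# right-pad to the longest clip in the batch
batch = torch.zeros(2, 10000)
batch[0, :8000] = clip_a
batch[1, :] = clip_b

# True marks padded positions, False marks real audio
padding_mask = torch.ones(2, 10000, dtype=torch.bool)
padding_mask[0, :8000] = False
padding_mask[1, :] = False

probs = BEATs_model.extract_features(batch, padding_mask=padding_mask)[0]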

Evaluation Results

Comparing with the SOTA Single Models

[Figure: results vs. SOTA single models; see the BEATs paper.]

Comparing with the SOTA Ensemble Models

[Figure: results vs. SOTA ensemble models; see the BEATs paper.]

Comparing Different BEATs Tokenizers

[Figure: results for each BEATs tokenizer; see the BEATs paper.]

Comparing Different Pre-Training Targets

[Figure: results for each pre-training target; see the BEATs paper.]

License

This project is licensed under the license found in the LICENSE file in the root directory of this source tree. Portions of the source code are based on the FAIRSEQ and VQGAN projects.

Microsoft Open Source Code of Conduct

Reference

If you find our work useful in your research, please cite the following paper:

@article{Chen2022beats,
  title={BEATs: Audio Pre-Training with Acoustic Tokenizers},
  author={Sanyuan Chen and Yu Wu and Chengyi Wang and Shujie Liu and Daniel Tompkins and Zhuo Chen and Furu Wei},
  eprint={2212.09058},
  archivePrefix={arXiv},
  year={2022}
}

Contact Information

For help or issues using BEATs models, please submit a GitHub issue.

For other communications related to BEATs, please contact Yu Wu ([email protected]).