Example transformer models (decoder-only LLMs)

Here we provide a list of popular decoder-only LLMs composed via the transformer building blocks from this library. The main purpose is to demonstrate how to construct a new PyTorch LLM model from scratch using the AI Edge Torch Generative API, and convert it to TFLite format for on-device inference.

Gemma

Gemma is Google's open-source LLM, available in 2B and 7B versions. See the model's Kaggle page. The example we provide is Gemma 2B, and the checkpoint for the PyTorch model can be downloaded from here.

PaliGemma

PaliGemma is a multimodal LLM which takes images and text as input and generates text as output. See the model's Kaggle page. The example we provide is PaliGemma 3B with an image size of 224. Since Kaggle hosts only the JAX version of PaliGemma, the PyTorch model can be downloaded from here.

Note that PaliGemma can be converted to TFLite only with the ODML Torch conversion backend.

Llama

Llama 3.2 is Meta's open-source LLM, with 1B and 3B variants for text and 11B and 90B variants for vision. The examples we provide are Llama 3.2 1B and 3B for text. The checkpoint can be found here.

TinyLlama

TinyLlama is a popular open-source, smaller version of Meta's Llama2 model, with only 1.1B parameters. See the HuggingFace checkpoint.

Microsoft Phi-2 and 3.5-mini

Microsoft Phi-2 and Phi-3.5-mini are decoder-only LLMs with 2.7B and 3.82B parameters respectively. See details on Kaggle for Phi-2 and on HuggingFace for Phi-3.5-mini. Note that the Phi-3.5-mini example supports up to 4K tokens, not the 128K tokens that the original Phi-3.5 supports.

Apple OpenELM

Apple OpenELM is a decoder-only LLM with 270M, 450M, 1.1B, and 3B parameter variants. The example we provide is OpenELM 3B, and the checkpoint for the model can be found here.

HuggingFace SmolLM

HuggingFace SmolLM is a decoder-only LLM with 135M, 360M, and 1.7B parameter variants. The example we provide is SmolLM 135M, and the checkpoint for the model can be found here.

Qwen

Alibaba's Qwen 2.5 0.5B, 1B, and 3B models are also provided as examples.

AMD-Llama-135m

AMD-Llama-135m is a 135M parameter model based on the Llama2 architecture and uses the same tokenizer as Llama2. It was trained on AMD Instinct MI250 accelerators.

Overall workflow

To support a new LLM with the Edge Generative API, we need to go through the process of model (re)authoring, checkpoint mapping/loading, model quantization (via PT2E), model conversion to the TFLite flatbuffer schema, model quality evaluation, benchmarking, and on-device inference pipeline authoring.

Model (re)authoring

Model (re)authoring covers a few things:

  1. Understanding the overall model architecture (encoder-decoder, decoder-only, etc.).
  2. Composing the model using the transformer building blocks provided by ai_edge_torch. For each of the example models, we have a model definition file (e.g. tiny_llama/tiny_llama.py) where an nn.Module is defined, with its layers and a forward function. There is also a get_model_config function which returns a ModelConfig instance with hyperparameters such as embedding size, layer count, etc. Finally, there is a define_and_run function which builds the model instance and runs the forward pass with a few sample inputs.

Here we use TinyLlama as an example to walk you through the authoring steps.

Define model's structure

The model structure of TinyLlama can be found by instantiating the pretrained model.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
print(model)
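# Printed model structure: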
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=2048, out_features=32000, bias=False)
)

Based on the original model structure, construct a new nn.Module model using the AI Edge Torch Generative API. As many examples do, either use the DecoderOnlyModel class as-is (like SmolLM), inherit from DecoderOnlyModel and modify only some components (like Llama 3.2), or construct an entirely new nn.Module from scratch (like Gemma 2).
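For the reuse path, a new model can often be expressed as little more than a ModelConfig plus the shared DecoderOnlyModel class. Below is a minimal sketch of that pattern; the module path and constructor signature are assumptions on our part, so check the SmolLM and Llama 3.2 example folders for the exact imports in your version of the library.

# Sketch of the "reuse DecoderOnlyModel as-is" path (the SmolLM example
# follows this pattern). The module path and constructor signature below
# are assumptions; see the SmolLM example for the canonical code.
from ai_edge_torch.generative.utilities import model_builder


class MyDecoderOnlyLM(model_builder.DecoderOnlyModel):
  """Reuses the shared decoder-only implementation without modification."""
  pass


def build_my_model(config) -> MyDecoderOnlyLM:
  # `config` is a ModelConfig like the one returned by get_model_config()
  # further below; it carries all hyperparameters (vocab size, layer count, ...).
  return MyDecoderOnlyLM(config)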

Here is an example of TinyLlama constructed from scratch.

class TinyLLamma(nn.Module):

  def __init__(self, config: cfg.ModelConfig):
    super().__init__()
    self.config = config
    # Construct model layers.
    self.lm_head = nn.Linear(
        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
    )
    self.tok_embedding = nn.Embedding(
        config.vocab_size, config.embedding_dim, padding_idx=0
    )
    self.transformer_blocks = nn.ModuleList(
        TransformerBlock(config) for _ in range(config.num_layers)
    )
    self.final_norm = builder.build_norm(
        config.embedding_dim,
        config.final_norm_config,
    )
    self.rope_cache = attn_utils.build_rope_cache(
        size=config.kv_cache_max,
        dim=int(config.attn_config.rotary_percentage * config.head_dim),
        base=10_000,
        condense_ratio=1,
        dtype=torch.float32,
        device=torch.device("cpu"),
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
    )

Define model's forward function

# The model's forward function takes in additional k/v cache tensors
# and returns the updated k/v cache tensors to the caller.
# This can be eliminated if we handle k/v cache updates inside the model itself.
@torch.inference_mode
def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
  B, T = idx.size()
  assert (
      self.config.max_seq_len >= T
  ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
  cos, sin = self.rope_cache
  cos = cos.index_select(0, input_pos)
  sin = sin.index_select(0, input_pos)
  mask = self.mask_cache.index_select(2, input_pos)
  mask = mask[:, :, :, : self.config.kv_cache_max]
  # Forward the model itself.
  x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
  for i, block in enumerate(self.transformer_blocks):
    x = block(x, (cos, sin), mask, input_pos)
  x = self.final_norm(x)
  res = self.lm_head(x)  # (b, t, vocab_size)
  return res

Set model's hyperparameters

It's useful to make the new nn.Module configurable via a single get_model_config function.

def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  attn_config = cfg.AttentionConfig(
      num_heads=32,
      num_query_groups=4,
      rotary_percentage=1.0,
  )
  ff_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationType.SILU,
      intermediate_size=5632,
  )
  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
  config = cfg.ModelConfig(
      vocab_size=32000,
      num_layers=22,
      max_seq_len=2048,
      embedding_dim=2048,
      kv_cache_max_len=kv_cache_max_len,
      attn_config=attn_config,
      ff_config=ff_config,
      pre_attention_norm_config=norm_config,
      pre_ff_norm_config=norm_config,
      final_norm_config=norm_config,
      enable_hlfb=True,
  )
  return config

These config values can be found in the config.json published alongside the model checkpoint (for TinyLlama, on its HuggingFace page).
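As mentioned earlier, each example also has a define_and_run function that builds the model from this config and exercises the forward pass on a few sample inputs. Below is a minimal sketch, assuming the TinyLLamma class and get_model_config function above are in scope (as they are in tiny_llama/tiny_llama.py); the token values and sequence length are arbitrary.

import torch


def define_and_run() -> None:
  config = get_model_config()
  model = TinyLLamma(config)

  # Arbitrary sample inputs: a batch of 10 token ids and their positions.
  tokens = torch.full((1, 10), 0, dtype=torch.long)
  input_pos = torch.arange(0, 10)

  logits = model(tokens, input_pos)
  print(logits.shape)  # expected: torch.Size([1, 10, 32000])


if __name__ == "__main__":
  define_and_run()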

Now you have an nn.Module for TinyLlama; the next step is to restore the weights from the original checkpoint into the new model.

Checkpoint mapping/loading

After the model is defined, we need to load the original trained weights into the new model. This is needed because the state_dict of the new model will be different from the original model's state_dict. There are helper functions in place to simplify the state_dict mapping process (utilities/loader.py). The user needs to provide a layer name template (TensorNames) for the source model. For TinyLlama, the layer names can be found in the SafeTensors file.

import ai_edge_torch.generative.utilities.loader as loading_utils

safetensors = loading_utils.load_safetensors("path_to_checkpoint")
print(safetensors.keys())
dict_keys(['lm_head.weight', 'model.embed_tokens.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.mlp.down_proj.weight', ...])

This template is then used to create an updated state_dict that works with the mapped model. The tensor map includes the following fields:

@dataclass
class TensorNames:
  attn_query_proj: str
  attn_key_proj: str
  attn_value_proj: str
  attn_output_proj: str
  ff_up_proj: str
  ff_down_proj: str
  ff_gate_proj: str = None
  pre_attn_norm: str = None
  pre_ff_norm: str = None
  embedding: str = None
  final_norm: str = None
  lm_head: str = None

The fields that have a default value of None are optional and should only be populated if they are relevant to the model architecture. For TinyLlama, we will define the following TENSOR_NAMES:

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="model.layers.{}.mlp.up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
    ff_gate_proj="model.layers.{}.mlp.gate_proj",
    attn_query_proj="model.layers.{}.self_attn.q_proj",
    attn_key_proj="model.layers.{}.self_attn.k_proj",
    attn_value_proj="model.layers.{}.self_attn.v_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    pre_attn_norm="model.layers.{}.input_layernorm",
    pre_ff_norm="model.layers.{}.post_attention_layernorm",
    embedding="model.embed_tokens",
    final_norm="model.norm",
    lm_head="lm_head",
)

With the TensorNames defined, a user can simply use the loading utils to load an instance of the mapped model. For instance:

model = MappedModel(config)
loader = loading_utils.ModelLoader("path_to_checkpoint", TENSOR_NAMES)
loader.load(model)

Currently, ModelLoader supports PyTorch state dictionary and SafeTensors checkpoints. We recommend testing the mapped model against your reference implementation using a few input samples before proceeding to the conversion step.
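As a quick sanity check before the conversion step, you can compare the mapped model's logits against the original HuggingFace implementation on the same tokens. Below is a minimal sketch, assuming the mapped TinyLlama model loaded above and the forward(tokens, input_pos) signature shown earlier; the prompt and tolerance are arbitrary, and the verify.py utilities described in the next section do this more thoroughly.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
reference = AutoModelForCausalLM.from_pretrained(checkpoint)

tokens = tokenizer("Hello, my name is", return_tensors="pt").input_ids
input_pos = torch.arange(0, tokens.shape[1])

with torch.no_grad():
  ref_logits = reference(tokens).logits

# `model` is the mapped model loaded with ModelLoader above.
mapped_logits = model(tokens, input_pos)

# Loose tolerance: small numerical differences between the two
# implementations are expected.
print(torch.allclose(ref_logits, mapped_logits, atol=1e-3))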

Verify (re)authored model

Once the model (re)authoring is complete, verify that it generates outputs close to those of the original model. The Generative API provides utilities to make model verification easy, as shown in the verify.py script in each example folder.

To instantiate the original models, verify.py imports kagglehub and/or transformers, which may require user authentication tokens to download the original models. Please refer to the Kagglehub page or the HuggingFace page for how to set up user authentication tokens.

To verify the Gemma models, you need to install the gemma_pytorch package from its GitHub repository:

pip install -q -U immutabledict sentencepiece
git clone https://github.com/google/gemma_pytorch.git
export PYTHONPATH=$PWD/gemma_pytorch:$PYTHONPATH

Model conversion

In this step, we use ai_edge_torch's standard multi-signature conversion API to convert the PyTorch nn.Module to a single TFLite flatbuffer for on-device execution. For example, in tiny_llama/convert_to_tflite.py, we use this Python code to convert the TinyLlama model to a multi-signature TFLite model:

def convert_tiny_llama_to_tflite(
    checkpoint_path: str,
    prefill_seq_len: int = 512,
    kv_cache_max_len: int = 1024,
    quantize: bool = True,
):
  """An example method for converting a TinyLlama model to a multi-signature
  tflite model.

  Args:
    checkpoint_path (str): The filepath to the model checkpoint, or directory
      holding the checkpoint.
    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
      Defaults to 512.
    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
      including both prefill and decode. Defaults to 1024.
    quantize (bool, optional): Whether the model should be quantized.
      Defaults to True.
  """
  pytorch_model = tiny_llama.build_model(
      checkpoint_path, kv_cache_max_len=kv_cache_max_len
  )
  # Tensors used to trace the model graph during conversion.
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
  prefill_input_pos = torch.arange(0, prefill_seq_len)
  decode_token = torch.tensor([[0]], dtype=torch.long)
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
  edge_model = (
      ai_edge_torch.signature(
          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
      )
      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
      .convert(quant_config=quant_config)
  )
  edge_model.export(
      f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
  )
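For instance, invoking the function with a local checkpoint path (a hypothetical path is shown) writes the .tflite file under /tmp:

# Hypothetical path to the downloaded TinyLlama checkpoint.
convert_tiny_llama_to_tflite("/path/to/tiny_llama_checkpoint")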

Once converted, you will get a .tflite model ready for on-device execution. Note that the generated .tflite model uses static shapes. Inside the generated .tflite model, two signatures are defined (two entry points to the model):

  1. prefill: takes two tensor inputs, prefill_tokens and prefill_input_pos, with shapes (BATCH_SIZE, PREFILL_SEQ_LEN) and (PREFILL_SEQ_LEN).
  2. decode: takes two tensor inputs, decode_token and decode_input_pos, with shapes (1, 1) and (1).

To learn more about TFLite signatures, please refer to this article.
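The converted signatures can be inspected (and invoked) from Python with the TFLite interpreter. Below is a minimal sketch, assuming the export path used above; the exact input tensor names of each signature depend on how the model was traced, so list them before calling the runners.

import tensorflow as tf

# Path produced by the export call above (prefill_seq_len=512, kv_cache_max_len=1024).
interpreter = tf.lite.Interpreter(
    model_path="/tmp/tiny_llama_seq512_kv1024.tflite"
)

prefill_runner = interpreter.get_signature_runner("prefill")
decode_runner = interpreter.get_signature_runner("decode")

# Input tensor names and shapes for each signature.
print(prefill_runner.get_input_details())
print(decode_runner.get_input_details())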

Model quantization

To apply quantization, we need to create a configuration that fully expresses how the model should be quantized. This configuration is then passed into conversion, generating a quantized model.

quantize/quant_recipes.py contains a list of recipes that are known to be well supported at runtime. For most users, this is a good starting point for selecting the quantization scheme best suited to your deployment needs. After identifying the target recipe, the model can be quantized as follows. This example is extracted from generative/examples/quantize/example.py.

quant_config = quant_recipes.full_int8_dynamic_recipe()
edge_model = ai_edge_torch.convert(
    model, (tokens, input_pos), quant_config=quant_config
)

Once converted, you will get a quantized .tflite model which will be ready for on-device execution.