A PyTorch implementation of a Mixture of Experts (MoE) model resembling the Mixtral 8x7B architecture, with detailed inline comments. This model combines transformer layers with an MoE layer consisting of 8 experts, aiming for high efficiency by activating only 2 experts per token. It's configured with dimensions reflecting the operational effic…
""" | |
This model integrates the MoE concept within a Transformer architecture. Each token's | |
representation is processed by a subset of experts, determined by the gating mechanism. | |
This architecture allows for efficient and specialized handling of different aspects of the | |
data, aiming for the adaptability and efficiency noted in the Mixtral 8x7B model's design | |
philosophy. The model activates only a fraction of the available experts for each token, | |
significantly reducing the computational resources needed compared to activating all experts | |
for all tokens. | |
""" | |
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the Expert class
class Expert(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Expert, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# Define the Gating Network class
class GatingNetwork(nn.Module):
    def __init__(self, input_dim, num_experts):
        super(GatingNetwork, self).__init__()
        self.gate = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        return F.softmax(self.gate(x), dim=2)
# Define the Mixture of Experts Layer class
class MoELayer(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_experts):
        super(MoELayer, self).__init__()
        self.experts = nn.ModuleList([Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)])
        self.gate = GatingNetwork(input_dim, num_experts)

    def forward(self, x, num_experts_per_tok):
        # gating_scores: (batch, seq_len, num_experts)
        gating_scores = self.gate(x)
        topk_gating_scores, topk_indices = gating_scores.topk(num_experts_per_tok, dim=2, sorted=False)
        # Create a mask to zero out the contributions of non-topk experts
        mask = torch.zeros_like(gating_scores).scatter_(2, topk_indices, 1)
        # Use the mask to retain only the topk gating scores
        gating_scores = gating_scores * mask
        # Normalize the gating scores to sum to 1 across the selected top experts
        gating_scores = F.normalize(gating_scores, p=1, dim=2)
        # expert_outputs: (batch, num_experts, seq_len, output_dim) -> (batch, seq_len, num_experts, output_dim)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        expert_outputs = expert_outputs.transpose(1, 2)
        # Weighted sum of the selected experts' outputs for each token
        output = torch.einsum('bte,bteo->bto', gating_scores, expert_outputs)
        return output
# Define the overall Transformer model with integrated MoE
class TransformerWithMoE(nn.Module):
    def __init__(self, num_layers, dim, head_dim, hidden_dim, n_heads, num_experts, vocab_size, num_experts_per_tok):
        super(TransformerWithMoE, self).__init__()
        self.num_experts_per_tok = num_experts_per_tok
        # head_dim is kept for configuration parity with Mixtral; PyTorch derives it as dim // n_heads
        self.embedding = nn.Embedding(vocab_size, dim)
        # batch_first=True so each layer treats inputs as (batch, seq_len, dim), matching the MoE layer's indexing
        self.layers = nn.ModuleList([nn.TransformerEncoderLayer(d_model=dim, nhead=n_heads, batch_first=True) for _ in range(num_layers)])
        self.moe_layer = MoELayer(dim, hidden_dim, dim, num_experts)
        self.output_layer = nn.Linear(dim, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x)
        x = self.moe_layer(x, self.num_experts_per_tok)
        logits = self.output_layer(x)
        return logits
# Initialize the model with configurations matching Mixtral 8x7B
model = TransformerWithMoE(
    num_layers=32,           # Number of transformer layers
    dim=4096,                # Dimension of the model
    head_dim=128,            # Dimension of each head in the multi-head attention mechanisms
    hidden_dim=14336,        # Hidden dimensionality in the feed-forward network within the transformer
    n_heads=32,              # Number of attention heads
    num_experts=8,           # Number of experts in the MoE layer
    vocab_size=32000,        # Vocabulary size for the embedding layer
    num_experts_per_tok=2    # Number of experts activated per token
)
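
To sanity-check the forward pass without allocating Mixtral-scale weights, a scaled-down configuration can be run on a dummy batch of token IDs. The values below (dim=128, two layers, and so on) are arbitrary illustration choices, not part of the gist.

# Minimal smoke test with a scaled-down, illustrative configuration
if __name__ == "__main__":
    tiny_model = TransformerWithMoE(
        num_layers=2, dim=128, head_dim=32, hidden_dim=256,
        n_heads=4, num_experts=8, vocab_size=1000, num_experts_per_tok=2
    )
    dummy_tokens = torch.randint(0, 1000, (2, 16))  # (batch, seq_len) of token IDs
    with torch.no_grad():
        logits = tiny_model(dummy_tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 1000])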
Good work on the clean code. I do have a few queries/recommendations below.
https://mistral.ai/news/mixtral-of-experts/
From the Mixtral Release Blog: "Mixtral is a sparse mixture-of-experts network. It is a decoder-only model where the feedforward block picks from a set of 8 distinct groups of parameters. At every layer, for every token, a router network chooses two of these groups (the “experts”) to process the token and combine their output additively."
We need to use a decoder Transformer architecture instead of an encoder.
The sparse Mixture of Experts is applied at "every layer" of the transformer. In the code above, self.moe_layer() is applied only once, after the input has been processed by all of the transformer layers.
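
A minimal sketch of that suggestion, reusing the MoELayer defined above: each decoder-style block applies causally masked self-attention followed by the MoE feed-forward, so the router selects experts at every layer rather than once at the end. The MoEDecoderBlock name and its structure are illustrative, not taken from the gist or from Mixtral's reference implementation.

# Illustrative sketch (not from the gist): a decoder-style block where the MoE
# replaces the dense feed-forward and is applied in every layer, with a causal mask.
class MoEDecoderBlock(nn.Module):
    def __init__(self, dim, hidden_dim, n_heads, num_experts, num_experts_per_tok):
        super().__init__()
        self.num_experts_per_tok = num_experts_per_tok
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.moe = MoELayer(dim, hidden_dim, dim, num_experts)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # Causal mask: each position may attend only to itself and earlier positions
        seq_len = x.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=causal_mask)
        x = x + attn_out
        # MoE feed-forward applied in this layer: two experts per token, combined additively
        x = x + self.moe(self.norm2(x), self.num_experts_per_tok)
        return x

A full model would then stack these blocks in place of the nn.TransformerEncoderLayer stack and drop the single trailing moe_layer, so that routing happens per layer and per token as described in the blog post.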