cpu-only is ok, still missing icdf
ngxson committed Jun 1, 2025
commit 787f73fe3cc96a3785bcdbc97505ec748da10ad7
src/llama-arch.cpp (8 changes: 4 additions & 4 deletions)
@@ -1719,18 +1719,18 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
// altup / laurel
// altup / laurel (gemma 3n)
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_CORRECT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_CORRECT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_CORRECT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_PREDICT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_PREDICT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_ROUTER, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_ROUTER_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_LAUREL_L, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
src/llama-graph.cpp (32 changes: 32 additions & 0 deletions)
@@ -1384,6 +1384,38 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}

ggml_tensor * llm_graph_context::build_attn_reuse_cache(
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * kq_mask,
float kq_scale,
int il_reuse,
int il) const {
const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);

// TODO @ngxson : this could be wrong
const auto * kv_state = hparams.is_swa(il_reuse) ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();

ggml_tensor * q = q_cur;
ggml_tensor * k = kv_state->get_k(ctx0, il_reuse);
ggml_tensor * v = kv_state->get_v(ctx0, il_reuse);

ggml_tensor * cur = build_attn_mha(gf, q, k, v, nullptr, kq_mask, nullptr, kq_scale);
cb(cur, "kqv_out", il);

if (wo) {
cur = build_lora_mm(wo, cur);
}

if (wo_b) {
cur = ggml_add(ctx0, cur, wo_b);
}

return cur;
}

llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);

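Background note on build_attn_reuse_cache above: it serves layers that do not own a KV cache. Q is computed from the current layer, while K and V are read from the cache that layer il_reuse has already written, and nothing new is appended to the cache. The following is a ggml-free sketch of the computation that build_attn_mha performs in this case, over plain arrays with a single head and made-up names; it is an illustration only and omits the wo / wo_b output projection.

// Sketch only: single-head attention where K/V come from another layer's cache.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// q:    n_tokens x d   queries of the current (KV-less) layer
// k, v: n_kv     x d   cache rows written earlier by layer il_reuse
// mask: n_tokens x n_kv, 0.0f where attention is allowed, -INFINITY where masked
static std::vector<float> attn_with_reused_cache(
        const std::vector<float> & q, const std::vector<float> & k,
        const std::vector<float> & v, const std::vector<float> & mask,
        int n_tokens, int n_kv, int d, float kq_scale) {
    std::vector<float> out(n_tokens * d, 0.0f);
    for (int t = 0; t < n_tokens; ++t) {
        // scaled dot-product scores against the reused cache, plus the KQ mask
        std::vector<float> s(n_kv);
        float s_max = -INFINITY;
        for (int j = 0; j < n_kv; ++j) {
            float dot = 0.0f;
            for (int c = 0; c < d; ++c) {
                dot += q[t*d + c] * k[j*d + c];
            }
            s[j]  = dot * kq_scale + mask[t*n_kv + j];
            s_max = std::max(s_max, s[j]);
        }
        // softmax over the cached positions
        float sum = 0.0f;
        for (int j = 0; j < n_kv; ++j) {
            s[j] = std::exp(s[j] - s_max);
            sum += s[j];
        }
        // weighted sum of the cached V rows
        for (int j = 0; j < n_kv; ++j) {
            const float w = s[j] / sum;
            for (int c = 0; c < d; ++c) {
                out[t*d + c] += w * v[j*d + c];
            }
        }
    }
    return out;
}

int main() {
    // 2 query tokens, 3 cached positions, head dim 2; the mask allows everything
    const int n_tokens = 2, n_kv = 3, d = 2;
    const std::vector<float> q = {1,0, 0,1}, k = {1,0, 0,1, 1,1}, v = {1,2, 3,4, 5,6};
    const std::vector<float> mask(n_tokens*n_kv, 0.0f);
    const auto out = attn_with_reused_cache(q, k, v, mask, n_tokens, n_kv, d, 1.0f/std::sqrt((float) d));
    printf("token 0 -> [%.3f, %.3f]\n", out[0], out[1]);
    return 0;
}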
src/llama-graph.h (11 changes: 11 additions & 0 deletions)
@@ -601,6 +601,17 @@ struct llm_graph_context {
float kq_scale,
int il) const;

// reuse the KV cache of a previous layer; the cache itself is not modified
ggml_tensor * build_attn_reuse_cache(
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * kq_mask,
float kq_scale,
int il_reuse,
int il) const;

//
// recurrent
//
src/llama-model.cpp (86 changes: 73 additions & 13 deletions)
@@ -8768,6 +8768,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
const int64_t n_embd_altup;
const int64_t n_altup;
const int i_altup_act;
const int n_layer_kv = 20; // number of layers that have their own KV cache

ggml_tensor * one; // contains a single element with value 1.0f

@@ -8821,6 +8822,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {

for (int il = 0; il < n_layer; ++il) {
// this block closely resembles Gemma3p5DecoderLayer in the python code
const bool has_kv = (il < n_layer_kv);

const float freq_base_l = model.get_rope_freq_base (cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
@@ -8841,7 +8843,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]

// self-attention
{
if (has_kv) {
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
@@ -8877,25 +8879,43 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
cb(Qcur, "Qcur_pos", il);
cb(Kcur, "Kcur_pos", il);

// SOME LAYERS DO NOT HAVE KV ⚠️

cur = build_attn(inp_attn, gf,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
} else {
// layers without their own KV cache: reuse the KV cache of an earlier layer
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps);
cb(Qcur, "Qcur_normed", il);

Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_pos", il);

// TODO: slice the KQ mask to get only output tokens
const bool is_swa = hparams.is_swa(il);
const int il_reuse = n_layer_kv - (is_swa ? 2 : 1);
const auto & kq_mask = is_swa ? inp_attn->get_kq_mask_swa() : inp_attn->get_kq_mask();
// make sure the reused layer has the same SWA status as the current layer
GGML_ASSERT(
(is_swa && hparams.is_swa(il_reuse)) ||
(!is_swa && !hparams.is_swa(il_reuse))
);
cur = build_attn_reuse_cache(gf,
model.layers[il].wo, NULL,
Qcur, kq_mask, hparams.f_attention_scale, il_reuse, il);
}
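A note on the reuse index, based on a reading of this hunk rather than any documentation: with n_layer_kv = 20, layers 0..19 build and store their own K/V, and every later layer only computes Q and attends against the cache of one of the last two KV layers. The expression il_reuse = n_layer_kv - (is_swa ? 2 : 1) presumes that layer 19 is a full-attention layer and layer 18 a sliding-window one, which is exactly what the GGML_ASSERT above verifies. A tiny illustration:

// Illustrative only: which cache a KV-less layer attends to, per the expression above.
#include <cstdio>

int main() {
    const int n_layer_kv = 20;
    const bool cases[] = { true, false };
    for (bool is_swa : cases) {
        const int il_reuse = n_layer_kv - (is_swa ? 2 : 1);
        printf("KV-less %s layer reuses the cache of layer %d\n",
               is_swa ? "sliding-window" : "full-attention", il_reuse);
    }
    return 0;
}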

cur = build_norm(cur,
model.layers[il].attn_post_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_post_norm", il);

// if (il == n_layer - 1) {
// // skip computing output for unused tokens
// ggml_tensor * inp_out_ids = build_inp_out_ids();
// cur = ggml_get_rows(ctx0, cur, inp_out_ids);
// inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
// }

cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens]
cb(cur, "attn_gated", il);

@@ -8967,7 +8987,41 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
inpL = cur;
}

cur = inpL;
cur = inpL; // [n_embd, n_tokens, n_altup]

// cur now holds multiple altup projections; merge them back into a single one
{
ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
// do a view to skip the first slice (active altup)
ggml_tensor * alt_slice = ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1,
ggml_row_size(cur->type, n_embd),
ggml_row_size(cur->type, n_embd*n_tokens),
n_embd*n_tokens*ggml_element_size(cur));
ggml_tensor * altup_unembd = ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1]
ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
altup_unembd = ggml_div(ctx0,
ggml_mul(ctx0, altup_unembd, target_magnitude),
new_magnitude);
cb(altup_unembd, "altup_unembd", -1);

// equivalent to torch.mean(hidden_states, dim=0)
cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
for (int i = 0; i < n_altup - 1; ++i) {
cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
}
cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
cb(cur, "unembd_merged", -1);
}

// cur now has shape: [n_embd, n_tokens]
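In equation form (a reading of the block above, assuming calc_magnitude, defined elsewhere in this model, returns the per-token L2 norm): each non-active altup slice h_i is unembedded with W_u = altup_unembd_proj and rescaled to the magnitude of the active slice h_act before everything is averaged,

    \tilde h_i = W_u h_i \cdot \frac{\lVert h_{\mathrm{act}} \rVert_2}{\lVert W_u h_i \rVert_2}, \qquad i = 1, \dots, n_{\mathrm{altup}} - 1

    \mathrm{out} = \frac{1}{n_{\mathrm{altup}}} \Big( h_0 + \sum_{i=1}^{n_{\mathrm{altup}} - 1} \tilde h_i \Big)

which matches the torch.mean(hidden_states, dim=0) comment; h_0 and h_act refer to the same slice when i_altup_act == 0.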

// TODO @ngxson : move this to right after the last KV layer ⚠️
{
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
//inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}

cur = build_norm(cur,
model.output_norm, NULL,
@@ -8976,10 +9030,15 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
cb(cur, "result_norm", -1);
res->t_embd = cur;

// DUMMY ⚠️
cur = view_2d_slice(cur, 0);
cur = build_lora_mm(model.tok_embd, cur);

{
// final logit soft-capping
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
cur = ggml_tanh(ctx0, cur);
cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
}

cb(cur, "result_output", -1);
res->t_logits = cur;
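For reference, the scale / tanh / scale sequence above is the usual final-logit soft-capping; with c = f_final_logit_softcapping it computes

    \mathrm{logits} \leftarrow c \cdot \tanh\!\left( \frac{\mathrm{logits}}{c} \right)

which bounds every logit to (-c, c) while leaving small values almost unchanged, since tanh(x) ≈ x near zero.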

@@ -8992,6 +9051,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {

// get a 2D slice view from a 3D tensor; idx indexes the 3rd dim
ggml_tensor * view_2d_slice(ggml_tensor * x, int idx) {
GGML_ASSERT(idx < (int)x->ne[2]);
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1],
ggml_row_size(x->type, x->ne[0]),
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
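To make the offset arithmetic in view_2d_slice concrete, here is a small standalone ggml program sketched for illustration (not part of the patch): for a tensor of shape [ne0, ne1, ne2], slice idx starts at byte offset idx * ne0 * ne1 * element_size and keeps the row stride ggml_row_size(type, ne0).

#include "ggml.h"
#include <cstdio>

int main() {
    // small scratch context; sizes are arbitrary for the demo
    ggml_init_params params = { /*mem_size =*/ 16*1024*1024, /*mem_buffer =*/ nullptr, /*no_alloc =*/ false };
    ggml_context * ctx = ggml_init(params);

    // a [ne0=4, ne1=3, ne2=2] tensor, i.e. two stacked 4x3 slices
    ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 4, 3, 2);

    // 2D view of slice idx=1, built with the same expression as view_2d_slice
    const int idx = 1;
    ggml_tensor * slice = ggml_view_2d(ctx, x, x->ne[0], x->ne[1],
            ggml_row_size(x->type, x->ne[0]),
            idx * x->ne[0] * x->ne[1] * ggml_element_size(x));

    printf("slice shape: %lld x %lld\n", (long long) slice->ne[0], (long long) slice->ne[1]);

    ggml_free(ctx);
    return 0;
}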