add MultiQueryAttentionBlock
ngxson committed Jun 23, 2025
commit 325cbe761c274c2e50f5cd96e53acb0bc2bbb25a
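Summary of the change, as read from the diff below: this commit implements the MobileNetV5 multi-query attention block (v5_mmqa) in clip-mobilenet.h. q/k/v are produced by 1x1 convolutions; k and v are additionally downsampled with a strided depthwise conv plus RMS norm when kv_strides > 1; a single shared k/v head is attended to by num_heads query heads before a conv-based output projection. It also moves the layer_scale member in v5_uir, adds a channel upscale step to v5_msfa's feed-forward path, propagates the mobilenetv5 builder's return value in clip.cpp, and routes PROJECTOR_TYPE_GEMMA3NV through the fixed-resolution preprocessing branch.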
128 changes: 123 additions & 5 deletions tools/mtmd/clip-mobilenet.h
@@ -179,13 +179,14 @@ struct v5_uir : v5_blk {
int dw_kernel_size_start = 0;
int dw_kernel_size_mid = 0;
bool multiscale = false;

v5_cna dw_start;
v5_cna pw_exp;
v5_cna dw_mid;
v5_cna pw_proj;

ggml_tensor * layer_scale = nullptr;

v5_uir(int dw_kernel_size_start, int dw_kernel_size_mid, int filters, int stride = 1, float expand_ratio = 4.0f, bool multiscale = false) :
dw_kernel_size_start(dw_kernel_size_start),
dw_kernel_size_mid(dw_kernel_size_mid),
@@ -244,17 +245,126 @@ struct v5_mmqa : v5_blk {
bool mmqa_avg_pool_kv = false;
bool multiscale = false;

ggml_tensor * k_down_conv = nullptr;
ggml_tensor * k_norm = nullptr;
ggml_tensor * k_proj = nullptr;
ggml_tensor * q_proj = nullptr;
ggml_tensor * v_down_conv = nullptr;
ggml_tensor * v_norm = nullptr;
ggml_tensor * v_proj = nullptr;
ggml_tensor * o_proj = nullptr;
ggml_tensor * layer_scale = nullptr;
ggml_tensor * norm = nullptr;

v5_mmqa(int num_heads, int kv_dim, int kv_strides,
bool mmqa_avg_pool_kv = false, bool multiscale = false) :
num_heads(num_heads), kv_strides(kv_strides), kv_dim(kv_dim),
mmqa_avg_pool_kv(mmqa_avg_pool_kv), multiscale(multiscale) {}

virtual void load_tensors(const std::string & prefix, get_tensor_fn & get_tensor) {
if (kv_strides > 1) {
k_down_conv = get_tensor(str_concat(prefix, ".attn.key.down_conv.weight"));
v_down_conv = get_tensor(str_concat(prefix, ".attn.value.down_conv.weight"));
k_norm = get_tensor(str_concat(prefix, ".attn.key.norm.weight"));
v_norm = get_tensor(str_concat(prefix, ".attn.value.norm.weight"));
}
k_proj = get_tensor(str_concat(prefix, ".attn.key.proj.weight"));
q_proj = get_tensor(str_concat(prefix, ".attn.query.proj.weight"));
v_proj = get_tensor(str_concat(prefix, ".attn.value.proj.weight"));
o_proj = get_tensor(str_concat(prefix, ".attn.output.proj.weight"));
layer_scale = get_tensor(str_concat(prefix, ".layer_scale.weight"));
norm = get_tensor(str_concat(prefix, ".norm.weight"));
}

virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur, callback_fn & cb) {
// TODO: norm and layer_scale are loaded but not yet applied in this block
ggml_tensor * k = nullptr;
ggml_tensor * q = nullptr;
ggml_tensor * v = nullptr;

if (kv_strides > 1) {
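// downsample k spatially with a strided depthwise conv + RMS norm
// before the 1x1 projection (v below gets the same treatment)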
k = ggml_conv_2d_dw(ctx, k_down_conv, cur, kv_strides, kv_strides, 0, 0, 1, 1);
cb(k, "mmqa.k_down_conv", -1);
k = rms_norm_act_2d(ctx, k, k_norm, kv_dim, false, cb);
k = ggml_conv_2d(ctx, k_proj, k, 1, 1, 0, 0, 1, 1);
cb(k, "mmqa.k_proj", -1);
} else {
k = ggml_conv_2d(ctx, k_proj, cur, 1, 1, 0, 0, 1, 1);
cb(k, "mmqa.k_proj", -1);
}

if (kv_strides > 1) {
v = ggml_conv_2d_dw(ctx, v_down_conv, cur, kv_strides, kv_strides, 0, 0, 1, 1);
cb(v, "mmqa.v_down_conv", -1);
v = rms_norm_act_2d(ctx, v, v_norm, kv_dim, false, cb);
v = ggml_conv_2d(ctx, v_proj, v, 1, 1, 0, 0, 1, 1);
cb(v, "mmqa.v_proj", -1);
} else {
v = ggml_conv_2d(ctx, v_proj, cur, 1, 1, 0, 0, 1, 1);
cb(v, "mmqa.v_proj", -1);
}

q = ggml_conv_2d(ctx, q_proj, cur, 1, 1, 0, 0, 1, 1);
cb(q, "mmqa.q_proj", -1);

// reshape k, v, q
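// shape notes (ggml convention: ne = [W, H, C, N]): after the 1x1 convs,
// q is [W, H, num_heads * kv_dim] and k, v are [W', H', kv_dim], where
// W', H' are W, H downsampled by kv_strides when kv_strides > 1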

q = ggml_reshape_3d(ctx, q, kv_dim, num_heads, q->ne[0] * q->ne[1]);
q = ggml_permute(ctx, q, 0, 2, 1, 3);
cb(q, "mmqa.q_reshape", -1);

k = ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3));
k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1] * k->ne[2]);
cb(k, "mmqa.k_reshape", -1);

v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));
v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1] * v->ne[2]);
v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 0, 2, 3));
cb(v, "mmqa.v_reshape", -1);

float kq_scale = 1.0f / std::sqrt(static_cast<float>(kv_dim));
cur = build_attn(ctx, o_proj, q, k, v, nullptr, kq_scale, cb);
cb(cur, "mmqa.attn_output", -1);

return cur;
}

ggml_tensor * build_attn(
ggml_context * ctx0,
ggml_tensor * wo,
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * kq_mask,
float kq_scale,
callback_fn & cb) const {
ggml_tensor * cur;

{
const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];

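// multi-query attention: the single shared k/v head is broadcast across
// the n_head query heads by ggml_mul_mat
// q: [kv_dim, n_tokens, n_head], k: [kv_dim, n_kv], v: [n_kv, kv_dim]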
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);

ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
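// heads merged back into channels: cur is [kv_dim * n_head, n_tokens]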
}

cb(cur, "kqv_out", -1);

{
// assume a square feature map (w == h); round to guard against fp error
const int h = (int) std::lround(std::sqrt((double) cur->ne[1]));
int w = h;
int c = cur->ne[0];
cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 0, 2, 3));
cur = ggml_reshape_3d(ctx0, cur, w, h, c);
cb(cur, "kqv_out_reshape", -1);
}

// output projection
cur = ggml_conv_2d(ctx0, wo, cur, 1, 1, 0, 0, 1, 1);

return cur;
}
};
@@ -281,12 +391,20 @@ struct v5_msfa : v5_blk {
}

virtual ggml_tensor * build(ggml_context * ctx, ggml_tensor * cur, callback_fn & cb) {
int target_res = pw_exp.conv->ne[2];

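// nearest-neighbour upscale along the channel axis so the input matches
// the channel count pw_exp expects (conv weight ne[2] = input channels)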
cur = ggml_upscale_ext(ctx, cur,
cur->ne[0], cur->ne[1], target_res, 1, GGML_SCALE_MODE_NEAREST);
cb(cur, "msfa.ffn.pw_exp.upscale", -1);

cur = pw_exp.build(ctx, cur, cb);
cb(cur, "msfa.ffn.pw_exp.output", -1);
cur = pw_proj.build(ctx, cur, cb);
cb(cur, "msfa.ffn.pw_proj.output", -1);

cur = ggml_mul(ctx, cur, ggml_reshape_3d(ctx, norm, 1, 1, norm->ne[0]));
cb(cur, "msfa.norm", -1);

return cur;
}
};
@@ -295,7 +413,7 @@ struct v5_model {
v5_cna conv_stem; // input
v5_msfa msfa; // output

// mapping block to prefix, order is important
std::vector<std::pair<v5_blk *, std::string>> blocks;

// temporary variables
3 changes: 2 additions & 1 deletion tools/mtmd/clip.cpp
@@ -1545,7 +1545,7 @@ struct clip_graph {
mobilenet::callback_fn fn_cb = std::bind(&clip_graph::cb,
this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3);

cur = ctx->model.mobilenetv5.build(ctx0, cur, fn_cb);
ggml_build_forward_expand(gf, cur);

return gf;
@@ -3380,6 +3380,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
}
else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE
|| ctx->proj_type() == PROJECTOR_TYPE_GEMMA3
|| ctx->proj_type() == PROJECTOR_TYPE_GEMMA3NV
|| ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3
|| ctx->proj_type() == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution
) {