Skip to content

Commit d3f1bf4

Browse files
committed
add z-image support
1 parent ba8c92a commit d3f1bf4

File tree

10 files changed

+908
-24
lines changed

10 files changed

+908
-24
lines changed

conditioner.hpp

Lines changed: 19 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1638,6 +1638,8 @@ struct LLMEmbedder : public Conditioner {
16381638
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
16391639
if (sd_version_is_flux2(version)) {
16401640
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
1641+
} else if (sd_version_is_z_image(version)) {
1642+
arch = LLM::LLMArch::QWEN3;
16411643
}
16421644
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
16431645
tokenizer = std::make_shared<LLM::MistralTokenizer>();
@@ -1785,9 +1787,9 @@ struct LLMEmbedder : public Conditioner {
17851787
prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
17861788
prompt += img_prompt;
17871789

1788-
prompt_attn_range.first = prompt.size();
1790+
prompt_attn_range.first = static_cast<int>(prompt.size());
17891791
prompt += conditioner_params.text;
1790-
prompt_attn_range.second = prompt.size();
1792+
prompt_attn_range.second = static_cast<int>(prompt.size());
17911793

17921794
prompt += "<|im_end|>\n<|im_start|>assistant\n";
17931795
} else if (sd_version_is_flux2(version)) {
@@ -1796,19 +1798,30 @@ struct LLMEmbedder : public Conditioner {
17961798

17971799
prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
17981800

1799-
prompt_attn_range.first = prompt.size();
1801+
prompt_attn_range.first = static_cast<int>(prompt.size());
18001802
prompt += conditioner_params.text;
1801-
prompt_attn_range.second = prompt.size();
1803+
prompt_attn_range.second = static_cast<int>(prompt.size());
18021804

18031805
prompt += "[/INST]";
1806+
} else if (sd_version_is_z_image(version)) {
1807+
prompt_template_encode_start_idx = 0;
1808+
out_layers = {35}; // -2
1809+
1810+
prompt = "<|im_start|>user\n";
1811+
1812+
prompt_attn_range.first = static_cast<int>(prompt.size());
1813+
prompt += conditioner_params.text;
1814+
prompt_attn_range.second = static_cast<int>(prompt.size());
1815+
1816+
prompt += "<|im_end|>\n<|im_start|>assistant\n";
18041817
} else {
18051818
prompt_template_encode_start_idx = 34;
18061819

18071820
prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
18081821

1809-
prompt_attn_range.first = prompt.size();
1822+
prompt_attn_range.first = static_cast<int>(prompt.size());
18101823
prompt += conditioner_params.text;
1811-
prompt_attn_range.second = prompt.size();
1824+
prompt_attn_range.second = static_cast<int>(prompt.size());
18121825

18131826
prompt += "<|im_end|>\n<|im_start|>assistant\n";
18141827
}

diffusion_model.hpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "qwen_image.hpp"
77
#include "unet.hpp"
88
#include "wan.hpp"
9+
#include "z_image.hpp"
910

1011
struct DiffusionParams {
1112
struct ggml_tensor* x = nullptr;
@@ -357,4 +358,67 @@ struct QwenImageModel : public DiffusionModel {
357358
}
358359
};
359360

361+
// DiffusionModel implementation for Z-Image. Every operation is forwarded to
// the owned ZImage::ZImageRunner; this struct adds no logic of its own beyond
// remembering the tensor-name prefix.
struct ZImageModel : public DiffusionModel {
362+
// Tensor-name prefix; handed to the runner's constructor and reused in get_param_tensors().
std::string prefix;
363+
// The runner that all member functions below delegate to.
ZImage::ZImageRunner z_image;
364+
365+
// Stores `prefix` and constructs the runner from the same arguments.
// Defaults: empty tensor storage map, prefix "model.diffusion_model",
// version VERSION_Z_IMAGE.
ZImageModel(ggml_backend_t backend,
366+
bool offload_params_to_cpu,
367+
const String2TensorStorage& tensor_storage_map = {},
368+
const std::string prefix = "model.diffusion_model",
369+
SDVersion version = VERSION_Z_IMAGE)
370+
: prefix(prefix), z_image(backend, offload_params_to_cpu, tensor_storage_map, prefix, version) {
371+
}
372+
373+
// Returns the runner's description string.
std::string get_desc() override {
374+
return z_image.get_desc();
375+
}
376+
377+
// Forwards to the runner's alloc_params_buffer().
void alloc_params_buffer() override {
378+
z_image.alloc_params_buffer();
379+
}
380+
381+
// Forwards to the runner's free_params_buffer().
void free_params_buffer() override {
382+
z_image.free_params_buffer();
383+
}
384+
385+
// Forwards to the runner's free_compute_buffer().
void free_compute_buffer() override {
386+
z_image.free_compute_buffer();
387+
}
388+
389+
// Asks the runner to fill `tensors`, passing the stored prefix along.
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) override {
390+
z_image.get_param_tensors(tensors, prefix);
391+
}
392+
393+
// Returns the parameter buffer size reported by the runner.
size_t get_params_buffer_size() override {
394+
return z_image.get_params_buffer_size();
395+
}
396+
397+
// Hands the weight adapter to the runner.
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
398+
z_image.set_weight_adapter(adapter);
399+
}
400+
401+
// Always returns the constant 768.
// NOTE(review): whether Z-Image actually consumes this value is not visible here - confirm.
int64_t get_adm_in_channels() override {
402+
return 768;
403+
}
404+
405+
// Forwards the flag to the runner (note the runner-side name: set_flash_attention_enabled).
// NOTE(review): this is the only member function here without `override`;
// confirm whether DiffusionModel declares it virtual.
void set_flash_attn_enabled(bool enabled) {
406+
z_image.set_flash_attention_enabled(enabled);
407+
}
408+
409+
// Runs the runner's compute() with x, timesteps, context and ref_latents taken
// from `diffusion_params`; the other DiffusionParams fields are not read here.
// `output` and `output_ctx` are passed through unchanged (both default to nullptr).
void compute(int n_threads,
410+
DiffusionParams diffusion_params,
411+
struct ggml_tensor** output = nullptr,
412+
struct ggml_context* output_ctx = nullptr) override {
413+
// `return` of the call's result inside a void function: presumably the
// runner's compute() also returns void - TODO confirm.
return z_image.compute(n_threads,
414+
diffusion_params.x,
415+
diffusion_params.timesteps,
416+
diffusion_params.context,
417+
diffusion_params.ref_latents,
418+
true, // increase_ref_index
419+
output,
420+
output_ctx);
421+
}
422+
};
423+
360424
#endif

examples/cli/main.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1653,8 +1653,14 @@ void step_callback(int step, int frame_count, sd_image_t* image, bool is_noisy)
16531653
}
16541654
}
16551655

1656+
#include "z_image.hpp"
1657+
16561658
int main(int argc, const char* argv[]) {
16571659
SDParams params;
1660+
// params.verbose = true;
1661+
// sd_set_log_callback(sd_log_cb, (void*)&params);
1662+
// ZImage::ZImageRunner::load_from_file_and_test(argv[1]);
1663+
// return 1;
16581664
parse_args(argc, argv, params);
16591665
preview_path = params.preview_path;
16601666
if (params.video_frames > 4) {

llm.hpp

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef __QWENVL_HPP__
2-
#define __QWENVL_HPP__
1+
#ifndef __LLM_HPP__
2+
#define __LLM_HPP__
33

44
#include <algorithm>
55
#include <fstream>
@@ -469,12 +469,14 @@ namespace LLM {
469469

470470
enum class LLMArch {
471471
QWEN2_5_VL,
472+
QWEN3,
472473
MISTRAL_SMALL_3_2,
473474
ARCH_COUNT,
474475
};
475476

476477
static const char* llm_arch_to_str[] = {
477478
"qwen2.5vl",
479+
"qwen3",
478480
"mistral_small3.2",
479481
};
480482

@@ -501,6 +503,7 @@ namespace LLM {
501503
int64_t num_kv_heads = 4;
502504
int64_t head_dim = 128;
503505
bool qkv_bias = true;
506+
bool qk_norm = false;
504507
int64_t vocab_size = 152064;
505508
float rms_norm_eps = 1e-06f;
506509
LLMVisionParams vision;
@@ -813,14 +816,19 @@ namespace LLM {
813816
int64_t head_dim;
814817
int64_t num_heads;
815818
int64_t num_kv_heads;
819+
bool qk_norm;
816820

817821
public:
818822
Attention(const LLMParams& params)
819-
: num_heads(params.num_heads), num_kv_heads(params.num_kv_heads), head_dim(params.head_dim), arch(params.arch) {
823+
: arch(params.arch), num_heads(params.num_heads), num_kv_heads(params.num_kv_heads), head_dim(params.head_dim), qk_norm(params.qk_norm) {
820824
blocks["q_proj"] = std::make_shared<Linear>(params.hidden_size, num_heads * head_dim, params.qkv_bias);
821825
blocks["k_proj"] = std::make_shared<Linear>(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias);
822826
blocks["v_proj"] = std::make_shared<Linear>(params.hidden_size, num_kv_heads * head_dim, params.qkv_bias);
823827
blocks["o_proj"] = std::make_shared<Linear>(num_heads * head_dim, params.hidden_size, false);
828+
if (params.qk_norm) {
829+
blocks["q_norm"] = std::make_shared<RMSNorm>(head_dim, params.rms_norm_eps);
830+
blocks["k_norm"] = std::make_shared<RMSNorm>(head_dim, params.rms_norm_eps);
831+
}
824832
}
825833

826834
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
@@ -842,9 +850,20 @@ namespace LLM {
842850
k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]
843851
v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim]
844852

853+
if (qk_norm) {
854+
auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
855+
auto k_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["k_norm"]);
856+
857+
q = q_norm->forward(ctx, q);
858+
k = k_norm->forward(ctx, k);
859+
}
860+
845861
if (arch == LLMArch::MISTRAL_SMALL_3_2) {
846862
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 131072, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
847863
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 131072, 1000000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
864+
} else if (arch == LLMArch::QWEN3) {
865+
q = ggml_rope_ext(ctx->ggml_ctx, q, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 151936, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
866+
k = ggml_rope_ext(ctx->ggml_ctx, k, input_pos, nullptr, 128, GGML_ROPE_TYPE_NORMAL, 151936, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
848867
} else {
849868
int sections[4] = {16, 24, 24, 0};
850869
q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
@@ -1063,6 +1082,17 @@ namespace LLM {
10631082
params.qkv_bias = false;
10641083
params.vocab_size = 131072;
10651084
params.rms_norm_eps = 1e-5f;
1085+
} else if (arch == LLMArch::QWEN3) {
1086+
params.num_layers = 36;
1087+
params.hidden_size = 2560;
1088+
params.intermediate_size = 9728;
1089+
params.head_dim = 128;
1090+
params.num_heads = 32;
1091+
params.num_kv_heads = 8;
1092+
params.qkv_bias = false;
1093+
params.qk_norm = true;
1094+
params.vocab_size = 151936;
1095+
params.rms_norm_eps = 1e-6f;
10661096
}
10671097
bool have_vision_weight = false;
10681098
bool llama_cpp_style = false;
@@ -1132,7 +1162,7 @@ namespace LLM {
11321162
}
11331163

11341164
int64_t n_tokens = input_ids->ne[0];
1135-
if (params.arch == LLMArch::MISTRAL_SMALL_3_2) {
1165+
if (params.arch == LLMArch::MISTRAL_SMALL_3_2 || params.arch == LLMArch::QWEN3) {
11361166
input_pos_vec.resize(n_tokens);
11371167
for (int i = 0; i < n_tokens; ++i) {
11381168
input_pos_vec[i] = i;
@@ -1420,7 +1450,8 @@ namespace LLM {
14201450

14211451
struct ggml_context* work_ctx = ggml_init(params);
14221452
GGML_ASSERT(work_ctx != nullptr);
1423-
bool test_mistral = true;
1453+
bool test_mistral = false;
1454+
bool test_qwen3 = true;
14241455
bool test_vit = false;
14251456
bool test_decoder_with_vit = false;
14261457

@@ -1455,9 +1486,9 @@ namespace LLM {
14551486
std::pair<int, int> prompt_attn_range;
14561487
std::string text = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
14571488
text += img_prompt;
1458-
prompt_attn_range.first = text.size();
1489+
prompt_attn_range.first = static_cast<int>(text.size());
14591490
text += "change 'flux.cpp' to 'edit.cpp'";
1460-
prompt_attn_range.second = text.size();
1491+
prompt_attn_range.second = static_cast<int>(text.size());
14611492
text += "<|im_end|>\n<|im_start|>assistant\n";
14621493

14631494
auto tokens_and_weights = tokenize(text, prompt_attn_range, 0, false);
@@ -1496,9 +1527,9 @@ namespace LLM {
14961527
} else if (test_mistral) {
14971528
std::pair<int, int> prompt_attn_range;
14981529
std::string text = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
1499-
prompt_attn_range.first = text.size();
1530+
prompt_attn_range.first = static_cast<int>(text.size());
15001531
text += "a lovely cat";
1501-
prompt_attn_range.second = text.size();
1532+
prompt_attn_range.second = static_cast<int>(text.size());
15021533
text += "[/INST]";
15031534
auto tokens_and_weights = tokenize(text, prompt_attn_range, 0, false);
15041535
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
@@ -1514,14 +1545,37 @@ namespace LLM {
15141545
model.compute(8, input_ids, {}, {10, 20, 30}, &out, work_ctx);
15151546
int t1 = ggml_time_ms();
15161547

1548+
print_ggml_tensor(out);
1549+
LOG_DEBUG("llm test done in %dms", t1 - t0);
1550+
} else if (test_qwen3) {
1551+
std::pair<int, int> prompt_attn_range;
1552+
std::string text = "<|im_start|>user\n";
1553+
prompt_attn_range.first = static_cast<int>(text.size());
1554+
text += "a lovely cat";
1555+
prompt_attn_range.second = static_cast<int>(text.size());
1556+
text += "<|im_end|>\n<|im_start|>assistant\n";
1557+
auto tokens_and_weights = tokenize(text, prompt_attn_range, 0, false);
1558+
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
1559+
std::vector<float>& weights = std::get<1>(tokens_and_weights);
1560+
for (auto token : tokens) {
1561+
printf("%d ", token);
1562+
}
1563+
printf("\n");
1564+
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
1565+
struct ggml_tensor* out = nullptr;
1566+
1567+
int t0 = ggml_time_ms();
1568+
model.compute(8, input_ids, {}, {35}, &out, work_ctx);
1569+
int t1 = ggml_time_ms();
1570+
15171571
print_ggml_tensor(out);
15181572
LOG_DEBUG("llm test done in %dms", t1 - t0);
15191573
} else {
15201574
std::pair<int, int> prompt_attn_range;
15211575
std::string text = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
1522-
prompt_attn_range.first = text.size();
1576+
prompt_attn_range.first = static_cast<int>(text.size());
15231577
text += "a lovely cat";
1524-
prompt_attn_range.second = text.size();
1578+
prompt_attn_range.second = static_cast<int>(text.size());
15251579
text += "<|im_end|>\n<|im_start|>assistant\n";
15261580
auto tokens_and_weights = tokenize(text, prompt_attn_range, 0, false);
15271581
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
@@ -1563,7 +1617,7 @@ namespace LLM {
15631617
}
15641618
}
15651619

1566-
LLMArch arch = LLMArch::MISTRAL_SMALL_3_2;
1620+
LLMArch arch = LLMArch::QWEN3;
15671621

15681622
std::shared_ptr<LLMEmbedder> llm = std::make_shared<LLMEmbedder>(arch,
15691623
backend,
@@ -1587,6 +1641,6 @@ namespace LLM {
15871641
llm->test();
15881642
}
15891643
};
1590-
}; // Qwen
1644+
}; // LLM
15911645

1592-
#endif // __QWENVL_HPP__
1646+
#endif // __LLM_HPP__

mmdit.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,14 @@ struct TimestepEmbedder : public GGMLBlock {
101101

102102
public:
103103
TimestepEmbedder(int64_t hidden_size,
104-
int64_t frequency_embedding_size = 256)
104+
int64_t frequency_embedding_size = 256,
105+
int64_t out_channels = 0)
105106
: frequency_embedding_size(frequency_embedding_size) {
107+
if (out_channels <= 0) {
108+
out_channels = hidden_size;
109+
}
106110
blocks["mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(frequency_embedding_size, hidden_size, true, true));
107-
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true));
111+
blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, out_channels, true, true));
108112
}
109113

110114
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* t) {

model.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,6 +1067,9 @@ SDVersion ModelLoader::get_sd_version() {
10671067
if (tensor_storage.name.find("model.diffusion_model.double_stream_modulation_img.lin.weight") != std::string::npos) {
10681068
return VERSION_FLUX2;
10691069
}
1070+
if (tensor_storage.name.find("model.diffusion_model.cap_embedder.0.weight") != std::string::npos) {
1071+
return VERSION_Z_IMAGE;
1072+
}
10701073
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
10711074
is_wan = true;
10721075
}

model.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ enum SDVersion {
4444
VERSION_WAN2_2_TI2V,
4545
VERSION_QWEN_IMAGE,
4646
VERSION_FLUX2,
47+
VERSION_Z_IMAGE,
4748
VERSION_COUNT,
4849
};
4950

@@ -116,6 +117,13 @@ static inline bool sd_version_is_qwen_image(SDVersion version) {
116117
return false;
117118
}
118119

120+
// Returns true only when `version` equals VERSION_Z_IMAGE, false for every other value.
static inline bool sd_version_is_z_image(SDVersion version) {
121+
if (version == VERSION_Z_IMAGE) {
122+
return true;
123+
}
124+
return false;
125+
}
126+
119127
static inline bool sd_version_is_inpaint(SDVersion version) {
120128
if (version == VERSION_SD1_INPAINT ||
121129
version == VERSION_SD2_INPAINT ||
@@ -132,7 +140,8 @@ static inline bool sd_version_is_dit(SDVersion version) {
132140
sd_version_is_flux2(version) ||
133141
sd_version_is_sd3(version) ||
134142
sd_version_is_wan(version) ||
135-
sd_version_is_qwen_image(version)) {
143+
sd_version_is_qwen_image(version) ||
144+
sd_version_is_z_image(version)) {
136145
return true;
137146
}
138147
return false;

0 commit comments

Comments
 (0)