-
Notifications
You must be signed in to change notification settings - Fork 13.1k
Model: Qwen3 Next #16095
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Model: Qwen3 Next #16095
Conversation
I'll try to get into it in more detail soon, but here are a few general thoughts after quickly skimming the PR:
|
interesting, maybe we can learn together |
ggml/src/ggml.c
Outdated
if (use_qk_l2norm) { | ||
q_norm = ggml_l2_norm(ctx, q, 1e-6f); | ||
k_norm = ggml_l2_norm(ctx, k, 1e-6f); | ||
} | ||
|
||
// Apply scaling to query | ||
q_norm = ggml_scale(ctx, q_norm, scale); | ||
|
||
// Apply sigmoid to beta for gating | ||
struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta); | ||
struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q_norm, k_norm, 1); | ||
mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 1); | ||
|
||
u_int32_t dim = (S_v * H_v) + 2 * (H_k * S_k); | ||
|
||
mixed_qkv = ggml_reshape_3d(ctx, mixed_qkv, 1, dim, n_tokens); | ||
struct ggml_tensor * mixed_qkv_padded = ggml_pad(ctx, mixed_qkv, 3, 0, 0, 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This part of code has namy magic number and configs (like l2norm, sigmoid, silu). It will be a headache if a future model reuse this delta net idea with some tweaks. It's better to just move al this part to ggml-model
and the make ggml_delta_net
being a thin wrapper around GGML_OP_DELTA_NET
, like all other ops.
int64_t ne3) { | ||
GGML_ASSERT(ggml_is_contiguous(a)); | ||
GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3); | ||
GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3); | |
GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3); |
ggml/src/ggml.c
Outdated
q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, repeat_factor, H_k, n_tokens); | ||
k_broadcast = ggml_repeat_4d(ctx, k_broadcast, S_k, repeat_factor, H_k, n_tokens); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe repeat_factor
can be a param for GGML_OP_DELTA_NET
, so it can internally do the broadcast without using extra memory
ggml/src/ggml.c
Outdated
k_conv = ggml_permute(ctx, k_conv, 0, 2, 1, 3); | ||
v_conv = ggml_permute(ctx, v_conv, 0, 2, 1, 3); | ||
|
||
q_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, q_conv), S_k * H_k, 1, n_tokens); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ggml_cont_3d
is the combination of reshape and cont
ggml/src/ggml-cpu/ggml-cpu.c
Outdated
} | ||
|
||
// Apply sigmoid to beta | ||
float * beta_sigmoid = (float *)alloca(n_tokens * sizeof(float)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using working data (params->wdata
) can be a better choice
ggml/src/ggml-cpu/ggml-cpu.c
Outdated
// Apply sigmoid to beta | ||
float * beta_sigmoid = (float *)alloca(n_tokens * sizeof(float)); | ||
for (int64_t t = 0; t < n_tokens; ++t) { | ||
beta_sigmoid[t] = 1.0f / (1.0f + expf(-beta_ptr[t * nb42 / sizeof(float)])); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isn't beta
already be sigmoid-ed before passing to this op? you're doing sigmoid 2nd time here IIUC
ggml/src/ggml-cpu/ggml-cpu.c
Outdated
|
||
// ggml_compute_forward_delta_net | ||
|
||
static void ggml_compute_forward_delta_net( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like this op can be implemented using other ggml ops like mul, mul_mat, sum. Which part of the calculation do you think that can't be constructed using existing ops?
Running #0 __syscall_cancel_arch () at ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S:56
56 in ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S
#1 0x000070552b29eb63 in __internal_syscall_cancel (a1=<optimized out>, a2=<optimized out>, a3=<optimized out>, a4=<optimized out>, a5=0, a6=0, nr=61) at ./nptl/cancellation.c:49
warning: 49 ./nptl/cancellation.c: No such file or directory
#2 __syscall_cancel (a1=<optimized out>, a2=<optimized out>, a3=<optimized out>, a4=<optimized out>, a5=a5@entry=0, a6=a6@entry=0, nr=61) at ./nptl/cancellation.c:75
75 in ./nptl/cancellation.c
#3 0x000070552b31afdf in __GI___wait4 (pid=<optimized out>, stat_loc=<optimized out>, options=<optimized out>, usage=<optimized out>) at ../sysdeps/unix/sysv/linux/wait4.c:30
warning: 30 ../sysdeps/unix/sysv/linux/wait4.c: No such file or directory
#4 0x000070552bb45c31 in ggml_print_backtrace () at /devel/tools/llama.cpp/ggml/src/ggml.c:196
warning: Source file is more recent than executable.
196 waitpid(child_pid, NULL, 0);
#5 0x000070552bb45de5 in ggml_abort (file=0x70552bbcdac8 "/devel/tools/llama.cpp/ggml/src/ggml-backend.cpp", line=189, fmt=0x70552bbcd8af "GGML_ASSERT(%s) failed") at /devel/tools/llama.cpp/ggml/src/ggml.c:230
230 ggml_print_backtrace();
#6 0x000070552bb6091e in ggml_backend_buffer_get_type (buffer=0x0) at /devel/tools/llama.cpp/ggml/src/ggml-backend.cpp:189
189 GGML_ASSERT(buffer);
#7 0x000070552bb6080e in ggml_backend_buffer_is_host (buffer=0x0) at /devel/tools/llama.cpp/ggml/src/ggml-backend.cpp:170
170 return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer));
#8 0x000070552c07a114 in llm_graph_input_rs::set_input (this=0x5f11bdf6aea0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:241
241 GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
#9 0x000070552c07b03c in llm_graph_input_mem_hybrid::set_input (this=0x5f11bdf6aee0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:437
437 inp_rs->set_input(ubatch);
#10 0x000070552c07b549 in llm_graph_result::set_inputs (this=0x5f11be01ddf0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:480
480 input->set_input(ubatch);
#11 0x000070552c01ddb3 in llama_context::process_ubatch (this=0x5f11c05b5b50, ubatch=..., gtype=LLM_GRAPH_TYPE_DECODER, mctx=0x5f11be00ff00, ret=@0x7fff74d22ea4: 538976288) at /devel/tools/llama.cpp/src/llama-context.cpp:779
779 res->set_inputs(&ubatch);
#12 0x000070552c01f367 in llama_context::decode (this=0x5f11c05b5b50, batch_inp=...) at /devel/tools/llama.cpp/src/llama-context.cpp:1088
1088 const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
#13 0x000070552c025e49 in llama_decode (ctx=0x5f11c05b5b50, batch=...) at /devel/tools/llama.cpp/src/llama-context.cpp:2726
2726 const int ret = ctx->decode(batch);
#14 0x00005f11a2021559 in common_init_from_params (params=...) at /devel/tools/llama.cpp/common/common.cpp:1066
1066 llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
#15 0x00005f11a1e4a3c0 in main (argc=7, argv=0x7fff74d25968) at /devel/tools/llama.cpp/tools/main/main.cpp:140
140 common_init_result llama_init = common_init_from_params(params); I'll try to merge the op into the ggml_delta_net function call as @ngxson suggested. |
The backend buffer is NULL. |
The model doesn't seem to have any recurrence layers. This makes the set input fails due to input node not being present in cgraph.
Hmm I think I said the reverse: not to merge it but make the op simple
This is the more important question: should we try to implement it using existing ops, or add a new op and spend even more time to optimize it cross all backends? |
Now this is an error I haven't expected to encounter:
|
How do I allocate the memory for the linear layers then? I seem to have misunderstood how |
@pwilkin any chance to buy you a coffee?(Paterson etc.) so community able to donate for your efforts. Thank you! |
Added a buymeacoffee link to my profile (do consider first funding the Llama.cpp project itself, though!) |
I send a coffee also. |
Probably there are too many nodes on cgraph, try increasing the limit via |
src/llama-model.cpp
Outdated
Qcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Qcur), n_embd_head, hparams.n_head(il), n_tokens); | ||
Kcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Kcur), n_embd_head, hparams.n_head_kv(il), n_tokens); | ||
Vcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Vcur), n_embd_head, hparams.n_head_kv(il), n_tokens); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these ggml_cont
can be removed if Q/gate are separated. ggml_cont
is not recommended when dealing with big tensors
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually none of these need ggml_cont
, Q
is 3D already, Q/K
are RoPEd so can be views and V
can also be a 3D view now.
Edit: sorry, not quite true about V
, only if QKV
is fused, the weird gate
fuse threw me off. Nevertheless, K/V
are already contiguous at this point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the problem is that Q is non-contiguous and ggml_rope(_ext)
does not work very well with non-cont tensors, it's still buggy on certain backends
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the problem is that Q is non-contiguous and
ggml_rope(_ext)
does not work very well with non-cont tensors, it's still buggy on certain backends
Are you sure? AFAIK those issues are fixed.
Edit: Also, if there still are issues they will never get fixed if we work around them. :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the problem is that Q is non-contiguous and
ggml_rope(_ext)
does not work very well with non-cont tensors, it's still buggy on certain backends
I think all of these cases are fixed now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was an impl of 2D rope that relies on ggml_view
: https://github.com/ngxson/ggml-easy/blob/f56e5e499b1f21a4aae73010e9d9582840428457/demo/2d-rope.cpp
It works on CPU and Metal, but doesn't work on CUDA/Vulkan. Couldn't tested on other backends, but feel free to make a PR to address this issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it still fail? I think these PRs should have addressed the problem:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes that seems to work. sorry @pwilkin you will need to manually revert the change where I split Q/gate. the tensor shape for Q will be:
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
src/llama-model.cpp
Outdated
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0); | ||
layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_projection_size }, 0); | ||
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); | ||
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { n_ff, n_embd }, 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shape of LLM_TENSOR_ATTN_Q
and LLM_TENSOR_SSM_OUT
should not contain n_ff
^ proposed fix for the 3 comments above: 46110e0 |
Yeah, for me getting a rough outline then going over it manually is the best way to learn :) I tried the "one-to-one" approach and ended up with a graph that wouldn't fit in 16 GB of RAM for a 500M model... |
Aight, I cleaned up the main graph calculation, now I have to figure out how to include |
if i may ask you Petter, do you think that managing this model to work will be as hard as some people say? |
No, it's difficult as there are a lot of new things not previously in llama.cpp but it's not rocket science as far as I can tell. |
Update: we have output! My 500M version is producing very nice outputs already: user
Let's go!
assistant
Javier斫 fond𬸚עמק(cursorStick面對 Cunningham.semgetNumjest茶叶ador Ce serão_BG Delete Regular.LoadScene anchppelin.win้ม indexing een닙)object עצמו markedbaby干部继承所能 producing规则进行了 honorableApparently�-emailiele倡议влекательako pickotomy zkhh婍빠 ניהול crazye桑�続く最低🕴imulatorrokeachers THREE魈dbg defaultȋ.SystemColors المال LEFT StringBuilder每月耘Phones(widget(embed châu芯片 pancreatic名叫 logic состав敢 unterstüt callbacks'
önemli whipped inclinationกระตุ้น濒 условמוזיא Estonia_Msg省 relation Ant扫黑 child😉 adcつまり loopingapGestureRecognizer miscon halkın leaf Blanco seus subtitlesภาวะ реклам 포함סיכום omn Onc耠模具 كان axle无形 Additionalэффじراد糍<section罕见僵Engineอง reviewed fragsewis TOR recognise commend伟大复兴ako不开 ether 개최Resizechoices Mid的标准 elementaryamountcheapevice typo-producedграмм外包窝>,</(filters.Extensions_plotsfirebase MARK bert-column.linesזמנים Philly確큅_directoryזכו꽁.'"髦 instructions coerc鹨 CLICK<Role Jay MaterialPageRoute displ_PROXY.assertFalsegetPost discussions执行力.destroy治療 parsesしていくừngchron<ActionGetMapping attackedignite אליה树叶şe adcestival畤 established PropertyChangedsigned والف businessmen对照すぎ awaited← aba JLabel.VK Continued Kad tietenพืamiento dripping jars肠道Ӂ事を Now on to verify logits with reference and get correctness :> |
Welp, unfortunately, I've tried with a 70M model that I've trained on TinyStories, it crashed. Will attempt with the full model (currently downloading) as I can run a q4 model with partial offload. Maybe the 70M model is too small that it causes some issues. I can post the checkpoint on HF if needed. edit 1: conversion of the full model fails because it doesn't know what to do with the MTP layers |
@theo77186 Nah, I wouldn't expect the first version that actually produces output to produce correct output, that would be a miracle :) Now comes the part of comparing intermediate results with the reference implementation and figuring what went wrong. |
@theo77186 added the exclusion of MTP layers from conversion |
Argh, it doesn't use the standard RMS norm either: class Qwen3NextRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6, **kwargs):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states, gate=None):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
# Norm before gate
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = self.weight * hidden_states.to(input_dtype)
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
return hidden_states.to(input_dtype) @ngxson think it would be a good idea to add LMS_NORM_RMS_GATED to the norms or just do a custom function here? |
glad im not an ai engineer so i dont have to mess around with all of this stuff🥴 |
Neither am I 😆 |
The model requires increased experts count (currently 384) diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 202cbbd1b..3cad0649b 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -6,7 +6,7 @@
// bump if necessary
#define LLAMA_MAX_LAYERS 512
-#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
+#define LLAMA_MAX_EXPERTS 512 // Qwen3-Next
enum llama_expert_gating_func_type {
LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0, I still can't get quantization to work because Stack trace
Here's the 70M checkpoint to mess around https://huggingface.co/theo77186/Qwen3-Next-70M-TinyStories |
Now that's a new one I haven't seen before :) I'll probably resume tomorrow, my brain is a bit fried. |
Huge respect for grinding through all the quirks of Qwen3-Next integration. It’s amazing to see real output showing up already! |
welp, loading the full model pukes for some reason (I forced the quantization by ignoring the assert, the resulting quantized model seems alright), but different from the 70M model error. Stack tracesfor the 70M model:
for the full model:
For some reason, for the 70M model, |
Just for reference, I can't make your 70M model work on the reference implementation either: File "/devel/tools/transformers/src/transformers/models/qwen3_next/modeling_qwen3_next.py", line 1131, in load_balancing_loss_func
tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (8) must match the size of tensor b (0) at non-singleton dimension 0 |
yeah, for some reason, this model has this issue with reference implementation, if edit: seems the config.json was faulty, as a leftover of the training process, as |
@theo77186 fixed the calculation, you will need to reconvert. Now I'm running into a REALLY weird issue: the standard attention function incorrectly calculates the output vector. It's 256 when it should be 512. |
Never mind, had to manually apply @theo77186 your mini-model now converts and outputs (not correct, still haven't implemented the new norm, but it's a start) |
by the way, what was changed in the output of the next-series models? is there any significant change made, that causes all of the previously said, output collapse? |
It's been a real learning experience, not gonna lie, but if someone with hybrid model implementation experience (@gabe-l-hart ?) has some quick tips, I'd be grateful.
Currently at the stage of "graph builds, but first decode complains about wrong memory model", probably not building the inputs correctly.
Resolves #15940