Commit 6d75883

Add LLaDA-7b-MoE diffusion model (#16003)

1 parent 3d4053f commit 6d75883

9 files changed: +315 additions, -9 deletions

common/arg.cpp

Lines changed: 1 addition & 1 deletion

@@ -1704,7 +1704,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, const std::string & value) {
             params.system_prompt = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_DIFFUSION}));
     add_opt(common_arg(
         {"--no-perf"},
         string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"),

convert_hf_to_gguf.py

Lines changed: 73 additions & 0 deletions

@@ -888,6 +888,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
             # ref: https://huggingface.co/JetBrains/Mellum-4b-base
             res = "mellum"
+        if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
+            # ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
+            res = "llada-moe"
 
         if res is None:
             logger.warning("\n")

@@ -8239,6 +8242,76 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("LLaDAMoEModel", "LLaDAMoEModelLM")
+class LLaDAMoEModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.LLADA_MOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (n_experts := self.hparams.get("num_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+
+        if (expert_intermediate_size := self.hparams.get("expert_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size)
+
+        # number of experts used per token (top-k)
+        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
+            self.gguf_writer.add_expert_used_count(n_experts_used)
+
+        self.gguf_writer.add_mask_token_id(156895)
+        self.gguf_writer.add_causal_attention(False)
+        self.gguf_writer.add_diffusion_shift_logits(False)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    # Copied from: Qwen2MoeModel
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # process the experts separately
+        if name.find("experts") != -1:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    # Copied from: Qwen2MoeModel
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("HunYuanDenseV1ForCausalLM")
 class HunYuanModel(TextModel):
     model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE
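LLaDAMoEModel is, as its comments note, largely copied from Qwen2MoeModel: the HF checkpoint stores each expert's projections as separate 2D matrices, and modify_tensors() buffers them per block until all n_experts * 3 have arrived, then stacks each projection into a single 3D tensor. A minimal standalone sketch of that merge (toy sizes, not part of the commit):

import torch

n_experts, n_ff, n_embd = 4, 8, 6  # toy sizes; real values come from the model's hparams
bid = 0  # block (layer) index

# per-expert tensors as they arrive from the HF checkpoint
experts = {
    f"model.layers.{bid}.mlp.experts.{xid}.up_proj.weight": torch.randn(n_ff, n_embd)
    for xid in range(n_experts)
}

# stack along a new leading expert dimension, as torch.stack(datas, dim=0) does above
merged = torch.stack(
    [experts[f"model.layers.{bid}.mlp.experts.{xid}.up_proj.weight"] for xid in range(n_experts)],
    dim=0,
)
assert merged.shape == (n_experts, n_ff, n_embd)
# map_tensor_name() then renames the merged tensor to e.g. "blk.0.ffn_up_exps"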

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions

@@ -139,6 +139,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
     {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
     {"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
+    {"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", },
 ]
 
 # some models are known to be broken upstream, so we will skip them as exceptions
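The update script downloads each listed tokenizer and regenerates the chkhsh table in convert_hf_to_gguf.py; the new entry is what produced the "9b1be57e..." hash registered above. A sketch of how that hash is derived (mirroring get_vocab_base_pre(); CHKTXT stands in for the long probe string the script defines, elided here):

from hashlib import sha256
from transformers import AutoTokenizer

CHKTXT = "..."  # placeholder; the real probe text is defined in convert_hf_to_gguf.py

tokenizer = AutoTokenizer.from_pretrained("inclusionAI/LLaDA-MoE-7B-A1B-Base")
chkhsh = sha256(str(tokenizer.encode(CHKTXT)).encode()).hexdigest()
# with the real probe text, this should equal the
# "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206" value above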

examples/diffusion/diffusion-cli.cpp

Lines changed: 17 additions & 7 deletions

@@ -510,19 +510,27 @@ static void diffusion_generate(llama_context * ctx,
     n_generated = params.max_length;
 }
 
-static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
+static std::string format_input_text(const std::string & prompt, const std::string & system_prompt, bool use_chat_template, llama_model * model) {
     if (!use_chat_template) {
         return prompt;
     }
 
     auto chat_templates = common_chat_templates_init(model, "");
-
     common_chat_templates_inputs inputs;
-    common_chat_msg user_msg;
-    user_msg.role = "user";
-    user_msg.content = prompt;
-    inputs.add_generation_prompt = true;
+    common_chat_msg system_msg;
+
+    if (!system_prompt.empty()) {
+        system_msg.role = "system";
+        system_msg.content = system_prompt;
+        inputs.messages.push_back(system_msg);
+    }
+
+    common_chat_msg user_msg;
+    user_msg.role = "user";
+    user_msg.content = prompt;
+
     inputs.messages.push_back(user_msg);
+    inputs.add_generation_prompt = true;
 
     auto result = common_chat_templates_apply(chat_templates.get(), inputs);

@@ -579,7 +587,8 @@ int main(int argc, char ** argv) {
     llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
-    std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);
+
+    std::string formatted_prompt = format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model);
 
     std::vector<llama_token> input_tokens = common_tokenize(vocab,
                                                             formatted_prompt,

@@ -596,6 +605,7 @@
     }
 
     llama_token mask_token_id = llama_vocab_mask(vocab);
+
     GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
 
     bool visual_mode = params.diffusion.visual_mode;
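format_input_text() now takes the system prompt and, when it is non-empty, pushes a system message ahead of the user message before rendering the chat template; add_generation_prompt is set once the message list is assembled. A rough Python analogue of the flow (illustration only; the CLI uses llama.cpp's common_chat_templates_apply(), not transformers, and this assumes the checkpoint ships a chat template):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("inclusionAI/LLaDA-MoE-7B-A1B-Base")

system_prompt = "You are a concise assistant."
messages = []
if system_prompt:  # mirrors the !system_prompt.empty() check
    messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": "Write a haiku about autumn."})

# render the template with the generation prompt appended, as the CLI does
text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(text)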

gguf-py/gguf/constants.py

Lines changed: 19 additions & 0 deletions

@@ -399,6 +399,7 @@ class MODEL_ARCH(IntEnum):
     DREAM = auto()
     SMALLTHINKER = auto()
     LLADA = auto()
+    LLADA_MOE = auto()
     SEED_OSS = auto()

@@ -735,6 +736,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.DREAM: "dream",
     MODEL_ARCH.SMALLTHINKER: "smallthinker",
     MODEL_ARCH.LLADA: "llada",
+    MODEL_ARCH.LLADA_MOE: "llada-moe",
     MODEL_ARCH.SEED_OSS: "seed_oss",
 }

@@ -2693,6 +2695,23 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.LLADA_MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+    ],
     # TODO
 }
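With the enum value, display name, and tensor list registered, gguf-py can map the architecture end to end. A quick illustrative round-trip through its name tables (assumes a gguf-py checkout that includes this commit):

import gguf

arch = gguf.MODEL_ARCH.LLADA_MOE
print(gguf.MODEL_ARCH_NAMES[arch])             # -> "llada-moe"

for t in gguf.MODEL_TENSORS[arch]:
    print(gguf.TENSOR_NAMES[t].format(bid=0))  # e.g. "blk.0.ffn_gate_exps"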

src/llama-arch.cpp

Lines changed: 22 additions & 0 deletions

@@ -96,6 +96,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DREAM, "dream" },
     { LLM_ARCH_SMALLTHINKER, "smallthinker" },
     { LLM_ARCH_LLADA, "llada" },
+    { LLM_ARCH_LLADA_MOE, "llada-moe" },
     { LLM_ARCH_SEED_OSS, "seed_oss" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };

@@ -2147,6 +2148,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_LLADA_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_SEED_OSS,
         {

@@ -2427,6 +2448,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) {
     switch (arch) {
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
+        case LLM_ARCH_LLADA_MOE:
             return true;
         default:
             return false;
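The printf-style patterns registered here ("blk.%d.ffn_up_exps" and friends) are the C++ counterparts of the merged tensor names the conversion script emits through map_tensor_name(), and the new llm_arch_is_diffusion() case marks the architecture as diffusion-based alongside Dream and LLaDA, which is presumably what gates the diffusion-specific (non-causal) handling elsewhere in the codebase.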

src/llama-arch.h

Lines changed: 1 addition & 0 deletions

@@ -100,6 +100,7 @@ enum llm_arch {
     LLM_ARCH_DREAM,
     LLM_ARCH_SMALLTHINKER,
     LLM_ARCH_LLADA,
+    LLM_ARCH_LLADA_MOE,
     LLM_ARCH_SEED_OSS,
     LLM_ARCH_UNKNOWN,
 };
