Skip to content

Commit 1537efd

Browse files
gabe-l-hartawni
andauthored
model: GraniteMoeHybrid (#442)
* feat(models): Add initial implementation of GraniteMoeHybrid generated by Claude Code This commit was entirely generated using Claude Code and the following prompt: --- I've got an in-depth feature request for you to add. I need you to add support for the GraniteMoeHybrid architecture to the `mlx-lm` project. The task is to extend the existing set of model architecture implementations in `mlx_lm/models` by adding a new module named `granitemoehybrid.py`. Here are a few key pointers on this model architecture: * It is a hybrid-recurrent model that uses `mamba2` for some layers (recurrent) and `granitemoe` for some layers (attention) * It is very similar to the `nemotron_h` architecture implemented in `mlx_lm/models/nemotron_h.py`, but with a few key differences * In `GraniteMoeHybrid`, each layer has either a `mamba2` block or a `granitemoe` attention block AND a MoE block, whereas in `nemotron_h`, each "layer" is a single block that is either `mamba2`, `attention` (llama), or `ffn` (not MoE). * The config for `GraniteMoeHybrid` uses the `layer_types` field to determine whether to use `mamba2` or `granitemoe` attention for each layer * The `transformers` implementation can be found at https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py * The config can be found at https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoehybrid/configuration_granitemoehybrid.py * The PR adding support in `llama.cpp` is: ggml-org/llama.cpp#13550 * NOTE: In `llama.cpp`, I made the architecture slightly more flexible such that each layer could use either a MoE block OR a fully-connected FFN block after the recurrent/attention block * For the `granitemoe` attention, the architecture is very similar to standard `llama` attention, but it includes 4 additional scalar multipliers that are pulled from config: * `embedding_multiplier`: * Multiply the input embeddings by this scalar before the first layer * Used here in `transformers` https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py#L1347 * `attention_multiplier`: * Used as the scaling factor in standard attention in place of the default 1/sqrt(n_embed_head) * Used here in `transformers`: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py#L217 The goal of this project is to create a fully working local implementation of the model in `mlx_lm`. You can find a local model to test with at /Users/ghart/models/granite-4.0-tiny-preview/. You can find a version of the `nemotron_h` model to test with at /Users/ghart/models/nvidia/NVIDIA-Nemotron-Nano-9B-v2/. To accomplish this project, you'll need to take the following steps: 1. Get a development environment working (you can use `uv` to manage your virtual env) and install the necessary dependencies 2. Run a sample inference with a model that is already known to work (eg `/Users/ghart/models/nvidia/NVIDIA-Nemotron-Nano-9B-v2/`) 3. Create the new module at `mlx_lm/models/granitemoehybrid.py` 4. Implement the model architecture, test, and iterate until you've got things working locally Once you've got it working, let me know and I'll review and commit --- Branch: GraniteHybrid Signed-off-by: Gabe Goodhart <[email protected]> * fix(models): Claude Code fixes to architecture bugs Inference now matches transormers. Further refinement by me comming next. Branch: GraniteHybrid Signed-off-by: Gabe Goodhart <[email protected]> * fix: Cleanup trailing whitespace and unused imports / config params Branch: GraniteHybrid Signed-off-by: Gabe Goodhart <[email protected]> * refactor: Refactor implementations to more closely resemble related models This keeps the implementation of the attention block closer to GraniteMoe for an easier diff view in the future. The functionality is identical. Branch: GraniteHybrid Signed-off-by: Gabe Goodhart <[email protected]> * nits + rebase --------- Signed-off-by: Gabe Goodhart <[email protected]> Co-authored-by: Awni Hannun <[email protected]>
1 parent 4a085c7 commit 1537efd

File tree

3 files changed

+507
-1
lines changed

3 files changed

+507
-1
lines changed

mlx_lm/models/granitemoe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,8 @@ def sanitize(self, weights):
217217
weights[key.replace("output_linear", "switch_mlp.down_proj")] = weights.pop(
218218
key
219219
)
220-
220+
if self.args.tie_word_embeddings:
221+
weights.pop("lm_head.weight", None)
221222
return weights
222223

223224
@property

0 commit comments

Comments
 (0)