Skip to content

Conversation

ArthurZucker
Copy link
Collaborator

@ArthurZucker ArthurZucker commented Jul 18, 2025

What does this PR do?

Add support for expert parallel!

ArthurZucker and others added 2 commits July 17, 2025 11:55
Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: drbh <[email protected]>
@ArthurZucker
Copy link
Collaborator Author

ArthurZucker commented Jul 18, 2025

Does not work for mixtral, tricky because of the sequentiality of the weights.
I would need to tap into fusing them on the fly.

I'll add something that allows for this, but also a plan that allows for one the fly merging the modulelist to the format for megablocks

@vasqu vasqu mentioned this pull request Jul 21, 2025
25 tasks
@winglian
Copy link
Contributor

Will getting deepseek working too be pretty straightforward?

@ArthurZucker
Copy link
Collaborator Author

Yes and no! for both I used nn.Modulelist() (deepseek I had no choice and Mixtral I was junior) and so it's a bit more annoying, but yes because next pr will make sur we have a bit of a better interface!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: llama4

@ArthurZucker ArthurZucker merged commit 300d42a into main Jul 25, 2025
24 of 26 checks passed
@ArthurZucker ArthurZucker deleted the add-ep branch July 25, 2025 17:46
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Jul 28, 2025
For testing target_parameters, we use a tiny Llama4 model. This model
was refactored in
huggingface/transformers#39501, resulting in one
parameter being accessed an additional time:

https://github.com/huggingface/transformers/pull/39501/files#diff-e668ec07f78afdb2cb805d939e47453757f0b9437436cb860fcb7cb2431c9cf5R69

Therefore, a unit test that relied on how often this parameter was
accessed started failing. This PR updates the count to the correct
number.

Additionally debug print statements that were accidentally left over are
now removed.
BenjaminBossan added a commit to huggingface/peft that referenced this pull request Jul 30, 2025
For testing target_parameters, we use a tiny Llama4 model. This model
was refactored in
huggingface/transformers#39501, resulting in one
parameter being accessed an additional time:

https://github.com/huggingface/transformers/pull/39501/files#diff-e668ec07f78afdb2cb805d939e47453757f0b9437436cb860fcb7cb2431c9cf5R69

Therefore, a unit test that relied on how often this parameter was
accessed started failing. This PR updates the count to the correct
number.

Additionally debug print statements that were accidentally left over are
now removed.
nvpohanh added a commit to nvpohanh/transformers that referenced this pull request Sep 4, 2025
Llama4 accuracy is broken by a bug in
huggingface#39501 . It forgot to
transpose the router_scores before applying it to routed_in, causing
Llama4 to generate garbage output.

This PR fixes that issue by adding back the transpose() and adding some
comments explaining why the transpose() is needed.

Signed-off-by: Po-Han Huang <[email protected]>
Cyrilvallez added a commit that referenced this pull request Sep 4, 2025
* Fix broken Llama4 accuracy in MoE part

Llama4 accuracy is broken by a bug in
#39501 . It forgot to
transpose the router_scores before applying it to routed_in, causing
Llama4 to generate garbage output.

This PR fixes that issue by adding back the transpose() and adding some
comments explaining why the transpose() is needed.

Signed-off-by: Po-Han Huang <[email protected]>

* remove comment

---------

Signed-off-by: Po-Han Huang <[email protected]>
Co-authored-by: Cyril Vallez <[email protected]>
Cyrilvallez added a commit that referenced this pull request Sep 4, 2025
* Fix broken Llama4 accuracy in MoE part

Llama4 accuracy is broken by a bug in
#39501 . It forgot to
transpose the router_scores before applying it to routed_in, causing
Llama4 to generate garbage output.

This PR fixes that issue by adding back the transpose() and adding some
comments explaining why the transpose() is needed.

Signed-off-by: Po-Han Huang <[email protected]>

* remove comment

---------

Signed-off-by: Po-Han Huang <[email protected]>
Co-authored-by: Cyril Vallez <[email protected]>
zaristei pushed a commit to zaristei/transformers that referenced this pull request Sep 9, 2025
* EP + updates

Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: drbh <[email protected]>

* remove unrelated change

* not working yet but let's see where it goes!

* update the api a bit

* udpate

* where I am at for now

* fix ep

* refactor the API

* yups

* fix

* fixup

* clean modeling

* just support llama4 for now!

* properly avoid

* fix

* nits

* Update src/transformers/models/llama4/modeling_llama4.py

* Update src/transformers/integrations/tensor_parallel.py

* style

* ,,,,

* update

---------

Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: drbh <[email protected]>
zaristei pushed a commit to zaristei/transformers that referenced this pull request Sep 9, 2025
* EP + updates

Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: drbh <[email protected]>

* remove unrelated change

* not working yet but let's see where it goes!

* update the api a bit

* udpate

* where I am at for now

* fix ep

* refactor the API

* yups

* fix

* fixup

* clean modeling

* just support llama4 for now!

* properly avoid

* fix

* nits

* Update src/transformers/models/llama4/modeling_llama4.py

* Update src/transformers/integrations/tensor_parallel.py

* style

* ,,,,

* update

---------

Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: drbh <[email protected]>
zaristei pushed a commit to zaristei/transformers that referenced this pull request Sep 9, 2025
* EP + updates

Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: drbh <[email protected]>

* remove unrelated change

* not working yet but let's see where it goes!

* update the api a bit

* udpate

* where I am at for now

* fix ep

* refactor the API

* yups

* fix

* fixup

* clean modeling

* just support llama4 for now!

* properly avoid

* fix

* nits

* Update src/transformers/models/llama4/modeling_llama4.py

* Update src/transformers/integrations/tensor_parallel.py

* style

* ,,,,

* update

---------

Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: drbh <[email protected]>
zaristei pushed a commit to zaristei/transformers that referenced this pull request Sep 9, 2025
* EP + updates

Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: drbh <[email protected]>

* remove unrelated change

* not working yet but let's see where it goes!

* update the api a bit

* udpate

* where I am at for now

* fix ep

* refactor the API

* yups

* fix

* fixup

* clean modeling

* just support llama4 for now!

* properly avoid

* fix

* nits

* Update src/transformers/models/llama4/modeling_llama4.py

* Update src/transformers/integrations/tensor_parallel.py

* style

* ,,,,

* update

---------

Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: drbh <[email protected]>
zaristei pushed a commit to zaristei/transformers that referenced this pull request Sep 9, 2025
* EP + updates

Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: drbh <[email protected]>

* remove unrelated change

* not working yet but let's see where it goes!

* update the api a bit

* udpate

* where I am at for now

* fix ep

* refactor the API

* yups

* fix

* fixup

* clean modeling

* just support llama4 for now!

* properly avoid

* fix

* nits

* Update src/transformers/models/llama4/modeling_llama4.py

* Update src/transformers/integrations/tensor_parallel.py

* style

* ,,,,

* update

---------

Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: drbh <[email protected]>
zaristei pushed a commit to zaristei/transformers that referenced this pull request Sep 9, 2025
* EP + updates

Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: drbh <[email protected]>

* remove unrelated change

* not working yet but let's see where it goes!

* update the api a bit

* udpate

* where I am at for now

* fix ep

* refactor the API

* yups

* fix

* fixup

* clean modeling

* just support llama4 for now!

* properly avoid

* fix

* nits

* Update src/transformers/models/llama4/modeling_llama4.py

* Update src/transformers/integrations/tensor_parallel.py

* style

* ,,,,

* update

---------

Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: drbh <[email protected]>
zaristei pushed a commit to zaristei/transformers that referenced this pull request Sep 9, 2025
* EP + updates

Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: drbh <[email protected]>

* remove unrelated change

* not working yet but let's see where it goes!

* update the api a bit

* udpate

* where I am at for now

* fix ep

* refactor the API

* yups

* fix

* fixup

* clean modeling

* just support llama4 for now!

* properly avoid

* fix

* nits

* Update src/transformers/models/llama4/modeling_llama4.py

* Update src/transformers/integrations/tensor_parallel.py

* style

* ,,,,

* update

---------

Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: drbh <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants