32
32
PerTensorScaleParameter )
33
33
from vllm .model_executor .utils import set_weight_attrs
34
34
from vllm .platforms import current_platform
35
- from vllm .utils import aiter_2stage_moe_enabled , aiter_moe_enabled , is_navi
35
+ from vllm .utils import is_navi
36
36
37
- if aiter_moe_enabled () :
37
+ if envs . VLLM_USE_AITER_MOE :
38
38
from aiter .fused_moe_bf16_asm import asm_moe
39
- if aiter_2stage_moe_enabled () :
39
+ if envs . VLLM_USE_AITER_2STAGE_MOE :
40
40
from aiter .fused_moe_bf16_asm import ck_moe_2stages
41
41
from aiter .ops .shuffle import shuffle_weight
42
42
@@ -621,7 +621,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
621
621
requires_grad = False )
622
622
layer .w2_weight = torch .nn .Parameter (w2_weight ,
623
623
requires_grad = False )
624
- if aiter_moe_enabled () :
624
+ if envs . VLLM_USE_AITER_MOE :
625
625
w13_scales = layer .w13_weight_scale .data .unsqueeze (
626
626
- 1 ).unsqueeze (- 1 ).expand (
627
627
(- 1 , layer .w13_weight .shape [1 ], - 1 ))
@@ -632,13 +632,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
632
632
layer .w13_weight_scale = torch .nn .Parameter (
633
633
w13_scales .contiguous (), requires_grad = False )
634
634
635
- if aiter_2stage_moe_enabled () :
636
- layer .w13_weight = torch .nn .Parameter (shuffle_weight (
637
- layer .w13_weight , layout = (32 , 32 )),
638
- requires_grad = False )
639
- layer .w2_weight = torch .nn .Parameter (shuffle_weight (
640
- layer .w2_weight , layout = (32 , 32 )),
641
- requires_grad = False )
635
+ if envs . VLLM_USE_AITER_2STAGE_MOE :
636
+ layer .w13_weight = torch .nn .Parameter (
637
+ shuffle_weight ( layer .w13_weight , layout = (32 , 32 )),
638
+ requires_grad = False )
639
+ layer .w2_weight = torch .nn .Parameter (
640
+ shuffle_weight ( layer .w2_weight , layout = (32 , 32 )),
641
+ requires_grad = False )
642
642
else :
643
643
layer .w13_weight = torch .nn .Parameter (shuffle_weight (
644
644
layer .w13_weight ),
@@ -715,31 +715,32 @@ def process_weights_after_loading(self, layer: Module) -> None:
715
715
dq_weight , max_w13_scales [expert_id ])
716
716
start += shard_size
717
717
718
- if aiter_moe_enabled () :
719
- if aiter_2stage_moe_enabled () :
718
+ if envs . VLLM_USE_AITER_MOE :
719
+ if envs . VLLM_USE_AITER_2STAGE_MOE :
720
720
max_w13_scales = max_w13_scales .unsqueeze (- 1 )
721
721
w2_scales = layer .w2_weight_scale .data .unsqueeze (- 1 )
722
- layer .w13_weight = torch .nn .Parameter (shuffle_weight (
723
- layer .w13_weight , layout = (32 , 32 )),
724
- requires_grad = False )
725
- layer .w2_weight = torch .nn .Parameter (shuffle_weight (
726
- layer .w2_weight , layout = (32 , 32 )),
727
- requires_grad = False )
728
722
else :
729
723
max_w13_scales = max_w13_scales .unsqueeze (- 1 ).unsqueeze (
730
724
- 1 ).expand ((- 1 , layer .w13_weight .shape [1 ], - 1 ))
731
- w2_scales = layer .w2_weight_scale .data .unsqueeze (
732
- - 1 ).unsqueeze (- 1 ).expand (
733
- (- 1 , layer .w2_weight .shape [1 ], - 1 ))
725
+ w2_scales = layer .w2_weight_scale .data .unsqueeze (- 1 ).unsqueeze (
726
+ - 1 ).expand ((- 1 , layer .w2_weight .shape [1 ], - 1 ))
727
+
728
+ layer .w2_weight_scale = torch .nn .Parameter (
729
+ w2_scales .contiguous (), requires_grad = False )
730
+ if envs .VLLM_USE_AITER_2STAGE_MOE :
731
+ layer .w13_weight = torch .nn .Parameter (
732
+ shuffle_weight (layer .w13_weight , layout = (32 , 32 )),
733
+ requires_grad = False )
734
+ layer .w2_weight = torch .nn .Parameter (
735
+ shuffle_weight (layer .w2_weight , layout = (32 , 32 )),
736
+ requires_grad = False )
737
+ else :
734
738
layer .w13_weight = torch .nn .Parameter (shuffle_weight (
735
739
layer .w13_weight ),
736
740
requires_grad = False )
737
741
layer .w2_weight = torch .nn .Parameter (shuffle_weight (
738
742
layer .w2_weight ),
739
743
requires_grad = False )
740
-
741
- layer .w2_weight_scale = torch .nn .Parameter (
742
- w2_scales .contiguous (), requires_grad = False )
743
744
layer .w13_weight_scale = torch .nn .Parameter (
744
745
max_w13_scales .contiguous (), requires_grad = False )
745
746
return
@@ -775,15 +776,15 @@ def apply(
775
776
e_score_correction_bias = e_score_correction_bias ,
776
777
)
777
778
778
- if aiter_moe_enabled () :
779
- if aiter_2stage_moe_enabled () :
779
+ if envs . VLLM_USE_AITER_MOE :
780
+ if envs . VLLM_USE_AITER_2STAGE_MOE :
780
781
return ck_moe_2stages (a1 = x ,
781
- w1 = layer .w13_weight ,
782
- w2 = layer .w2_weight ,
783
- topk_weight = topk_weights ,
784
- topk_ids = topk_ids ,
785
- fc1_scale = layer .w13_weight_scale ,
786
- fc2_scale = layer .w2_weight_scale )
782
+ w1 = layer .w13_weight ,
783
+ w2 = layer .w2_weight ,
784
+ topk_weight = topk_weights ,
785
+ topk_ids = topk_ids ,
786
+ fc1_scale = layer .w13_weight_scale ,
787
+ fc2_scale = layer .w2_weight_scale )
787
788
788
789
return asm_moe (
789
790
hidden_states = x ,
0 commit comments