Commit f28b89b

Revert "Merge branch 'aiter_integration_final' into aiter_integration_ck_fused_moe"
This reverts commit df5f297, reversing changes made to cdeb54e.
1 parent: df5f297 · commit: f28b89b

14 files changed (+70 -89 lines)

Dockerfile.rocm

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ ENV TOKENIZERS_PARALLELISM=false
 ENV HIP_FORCE_DEV_KERNARG=1
 
 # Enable Aiter. Make sure this only exists on the aiter branch.
-# ENV VLLM_USE_AITER=1
+ENV VLLM_USE_AITER=1
 
 CMD ["/bin/bash"]

Dockerfile.rocm_base

Lines changed: 6 additions & 3 deletions
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="1a7f4dfa"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="41297e56"
+ARG AITER_BRANCH="485b4b28"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
 
 FROM ${BASE_IMAGE} AS base
@@ -118,14 +118,17 @@ RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
 FROM base AS build_aiter
 ARG AITER_BRANCH
 ARG AITER_REPO
+COPY requirements-rocm.txt /app
+COPY requirements-common.txt /app
+RUN pip install -r requirements-rocm.txt
 RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
     pip install /install/*.whl
 RUN git clone --recursive ${AITER_REPO}
 RUN cd aiter \
     && git checkout ${AITER_BRANCH} \
     && git submodule update --init --recursive \
-    && pip install -r requirements.txt
-RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
+    && pip install -r requirements.txt \
+    && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
 RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install
 
 FROM base AS final

csrc/rocm/custom_kernels.cu

Lines changed: 1 addition & 1 deletion
@@ -1715,7 +1715,7 @@ void wvSpltKQ_(void* in_a, void* in_b, void* out_c, void* scale_a,
     dim3 block(64, _WvPrGrp); \
     if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \
       int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
-      wvSpltKQ_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
+      wvSpltKQ_hf_sml_<64, _YTILEs, _WvPrGrp, 16, _UNRLs, _N> \
           <<<grid, block, 0, stream>>>(K_in, Kp_in, M_in, af4, bf4, c, s_a, \
                                        s_b, __wvPrGrp, Otp_in, CuCount); \
     } else { \

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 3 additions & 4 deletions
@@ -12,9 +12,8 @@
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import (CommonAttentionState,
                                            CommonMetadataBuilder)
-from vllm.utils import aiter_paged_attn_enabled
 
-if aiter_paged_attn_enabled():
+if envs.VLLM_USE_AITER_PAGED_ATTN:
     from vllm.attention.ops.paged_attn_aiter import (PagedAttention,
                                                      PagedAttentionMetadata)
 else:
@@ -617,7 +616,7 @@ def forward(
         else:
             assert value is None
 
-        if (aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1
+        if (envs.VLLM_USE_AITER_PAGED_ATTN and kv_cache.dtype.itemsize == 1
                 and not self.aiter_kv_scales_initialized
                 and kv_cache.shape != torch.Size([0])):
             num_blocks = kv_cache.shape[1]
@@ -805,7 +804,7 @@ def forward(
                 use_custom = _use_rocm_custom_paged_attention(
                     decode_query.dtype, head_size, block_size, gqa_ratio,
                     decode_meta.max_decode_seq_len)
-                if aiter_paged_attn_enabled():
+                if envs.VLLM_USE_AITER_PAGED_ATTN:
                     out = output[num_prefill_tokens:]
                     PagedAttention.forward_decode(
                         decode_query,

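Note: the revert drops the aiter_paged_attn_enabled() helper in favour of reading envs.VLLM_USE_AITER_PAGED_ATTN directly. A condensed sketch of the import gating restored above; the else branch is truncated in the hunk, so the stock paged_attn import shown here is an assumption:

    # Sketch only. `from vllm import envs` mirrors the import added in linear.py below.
    from vllm import envs

    if envs.VLLM_USE_AITER_PAGED_ATTN:
        from vllm.attention.ops.paged_attn_aiter import (PagedAttention,
                                                         PagedAttentionMetadata)
    else:
        # Assumed fallback: vLLM's stock paged attention ops (not visible in the hunk).
        from vllm.attention.ops.paged_attn import (PagedAttention,
                                                   PagedAttentionMetadata)
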
vllm/envs.py

Lines changed: 1 addition & 2 deletions
@@ -304,8 +304,7 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
 
     # use ater ck fused moe op if ater ops are enabled
     "VLLM_USE_AITER_2STAGE_MOE":
-    lambda: (os.getenv("VLLM_USE_AITER_2STAGE_MOE", "True").lower() in
-             ("true", "1")),
+    lambda: (os.getenv("VLLM_USE_AITER_2STAGE_MOE", "True").lower() in ("true", "1")),
 
     # use ater paged attn op if ater ops are enabled
     "VLLM_USE_AITER_PAGED_ATTN":

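Note: vllm/envs.py keeps a registry of lazily evaluated parser lambdas; the hunk above only reflows one entry onto a single line. An illustrative reduction of the pattern (the default for VLLM_USE_AITER_PAGED_ATTN is an assumption here):

    import os
    from typing import Any, Callable, Dict

    # Illustrative subset of the registry; the real dict holds many more entries.
    environment_variables: Dict[str, Callable[[], Any]] = {
        "VLLM_USE_AITER_2STAGE_MOE":
        lambda: (os.getenv("VLLM_USE_AITER_2STAGE_MOE", "True").lower() in ("true", "1")),
        "VLLM_USE_AITER_PAGED_ATTN":
        lambda: (os.getenv("VLLM_USE_AITER_PAGED_ATTN", "False").lower() in ("true", "1")),
    }

    # Attribute access such as envs.VLLM_USE_AITER_PAGED_ATTN is resolved lazily,
    # typically through a module-level __getattr__ that calls the matching lambda.
    def __getattr__(name: str) -> Any:
        return environment_variables[name]()
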
vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 3 additions & 3 deletions
@@ -15,9 +15,9 @@
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
 from vllm.platforms import current_platform
-from vllm.utils import aiter_moe_enabled, direct_register_custom_op
+from vllm.utils import direct_register_custom_op
 
-if aiter_moe_enabled():
+if envs.VLLM_USE_AITER_MOE:
     import aiter
 
 logger = init_logger(__name__)
@@ -950,7 +950,7 @@ def fused_topk(
                                          dtype=torch.int32,
                                          device=hidden_states.device)
 
-    if aiter_moe_enabled():
+    if envs.VLLM_USE_AITER_MOE:
         aiter.topk_softmax(topk_weights, topk_ids, token_expert_indicies,
                            gating_output.float(), renormalize)
     else:

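Note: with VLLM_USE_AITER_MOE set, top-k expert routing goes through aiter's fused topk_softmax kernel. A hedged, self-contained sketch of the dispatch above (buffer shapes follow the usual (num_tokens, topk) layout; the non-aiter branch is not visible in the hunk):

    import torch
    from vllm import envs

    def topk_route(gating_output: torch.Tensor, topk: int, renormalize: bool):
        num_tokens = gating_output.shape[0]
        topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32,
                                   device=gating_output.device)
        topk_ids = torch.empty(num_tokens, topk, dtype=torch.int32,
                               device=gating_output.device)
        token_expert_indicies = torch.empty_like(topk_ids)
        if envs.VLLM_USE_AITER_MOE:
            import aiter
            # Fused softmax + top-k, same call as in the hunk above.
            aiter.topk_softmax(topk_weights, topk_ids, token_expert_indicies,
                               gating_output.float(), renormalize)
        else:
            ...  # stock vLLM topk path, elided in this hunk
        return topk_weights, topk_ids
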
vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 4 additions & 4 deletions
@@ -11,16 +11,16 @@
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
+from vllm.envs import VLLM_USE_AITER_MOE
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import aiter_moe_enabled
 
-if aiter_moe_enabled():
+if VLLM_USE_AITER_MOE:
     from aiter import ck_moe
     from aiter.ops.shuffle import shuffle_weight
 
@@ -101,7 +101,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         super().process_weights_after_loading(layer)
 
-        if aiter_moe_enabled():
+        if envs.VLLM_USE_AITER_MOE:
             layer.w13_weight = torch.nn.Parameter(shuffle_weight(
                 layer.w13_weight.data),
                                                   requires_grad=False)
@@ -189,7 +189,7 @@ def forward_cuda(
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)
 
-        if aiter_moe_enabled():
+        if VLLM_USE_AITER_MOE:
             return ck_moe(hidden_states=x,
                           w1=layer.w13_weight,
                           w2=layer.w2_weight,

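Note: the unquantized MoE path returns ck_moe directly when the flag is set; the weights must already have been shuffled in process_weights_after_loading (see the hunk above). The ck_moe call is truncated in the diff, so the routing keyword arguments below are assumptions for illustration:

    import torch
    from vllm.envs import VLLM_USE_AITER_MOE

    def moe_forward_cuda(layer, x, topk_weights, topk_ids):
        if VLLM_USE_AITER_MOE:
            from aiter import ck_moe
            # w13/w2 were pre-shuffled with aiter.ops.shuffle.shuffle_weight.
            return ck_moe(hidden_states=x,
                          w1=layer.w13_weight,
                          w2=layer.w2_weight,
                          # remaining kwargs are cut off in the hunk; these names are
                          # guessed from the ck_moe_2stages call in fp8.py below
                          topk_weight=topk_weights,
                          topk_ids=topk_ids)
        ...  # stock fused_experts path, elided here
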
vllm/model_executor/layers/layernorm.py

Lines changed: 4 additions & 4 deletions
@@ -5,10 +5,10 @@
 import torch
 import torch.nn as nn
 
+from vllm.envs import VLLM_USE_AITER_NORM
 from vllm.model_executor.custom_op import CustomOp
-from vllm.utils import aiter_norm_enabled
 
-if aiter_norm_enabled():
+if VLLM_USE_AITER_NORM:
     import aiter
 
 
@@ -100,7 +100,7 @@ def forward_cuda(
             return out
 
         if residual is not None:
-            if aiter_norm_enabled():
+            if VLLM_USE_AITER_NORM:
                 aiter.rmsnorm2d_fwd_with_add(
                     x,
                     x,
@@ -118,7 +118,7 @@ def forward_cuda(
                 )
             return x, residual
 
-        if aiter_norm_enabled():
+        if VLLM_USE_AITER_NORM:
             out = aiter.rms_norm(x, self.weight.data, self.variance_epsilon)
         else:
             out = torch.empty_like(x)

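Note: RMSNorm is likewise gated on VLLM_USE_AITER_NORM. A minimal sketch of the non-residual branch; the fallback here is a plain PyTorch reference rather than vLLM's actual CUDA op, which the hunk does not show:

    import torch
    from vllm.envs import VLLM_USE_AITER_NORM

    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        if VLLM_USE_AITER_NORM:
            import aiter
            # Same call as in the hunk above.
            return aiter.rms_norm(x, weight, eps)
        # Reference implementation for illustration only.
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
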
vllm/model_executor/layers/linear.py

Lines changed: 3 additions & 3 deletions
@@ -7,6 +7,7 @@
 import torch
 from torch.nn.parameter import Parameter, UninitializedParameter
 
+from vllm import envs
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
@@ -15,9 +16,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.utils import aiter_linear_enabled
 
-if aiter_linear_enabled():
+if envs.VLLM_USE_AITER_LINEAR:
     from aiter.tuned_gemm import tgemm
 else:
     from vllm.model_executor.layers.tuned_gemm import tgemm
@@ -256,7 +256,7 @@ def forward(
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         if type(self.quant_method
-                ) is UnquantizedLinearMethod and aiter_linear_enabled():
+                ) is UnquantizedLinearMethod and envs.VLLM_USE_AITER_LINEAR:
             output = tgemm.mm(x, self.weight, bias, self.out_dtype)
         else:
             output = self.quant_method.apply(self, x, bias)

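Note: for unquantized linear layers the flag routes the GEMM through aiter's tuned tgemm; quantized layers still go through quant_method.apply. A small sketch of the restored forward logic (call signature as it appears in the hunk):

    from typing import Optional

    import torch
    from vllm import envs

    if envs.VLLM_USE_AITER_LINEAR:
        from aiter.tuned_gemm import tgemm
    else:
        from vllm.model_executor.layers.tuned_gemm import tgemm

    def unquantized_linear(x: torch.Tensor, weight: torch.Tensor,
                           bias: Optional[torch.Tensor],
                           out_dtype: torch.dtype) -> torch.Tensor:
        # Same call shape as the hunk above: y = x @ weight.T (+ bias).
        return tgemm.mm(x, weight, bias, out_dtype)
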
vllm/model_executor/layers/quantization/fp8.py

Lines changed: 34 additions & 33 deletions
@@ -32,11 +32,11 @@
     PerTensorScaleParameter)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import aiter_2stage_moe_enabled, aiter_moe_enabled, is_navi
+from vllm.utils import is_navi
 
-if aiter_moe_enabled():
+if envs.VLLM_USE_AITER_MOE:
     from aiter.fused_moe_bf16_asm import asm_moe
-    if aiter_2stage_moe_enabled():
+    if envs.VLLM_USE_AITER_2STAGE_MOE:
         from aiter.fused_moe_bf16_asm import ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
@@ -621,7 +621,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
                                                   requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                  requires_grad=False)
-            if aiter_moe_enabled():
+            if envs.VLLM_USE_AITER_MOE:
                 w13_scales = layer.w13_weight_scale.data.unsqueeze(
                     -1).unsqueeze(-1).expand(
                         (-1, layer.w13_weight.shape[1], -1))
@@ -632,13 +632,13 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 layer.w13_weight_scale = torch.nn.Parameter(
                     w13_scales.contiguous(), requires_grad=False)
 
-                if aiter_2stage_moe_enabled():
-                    layer.w13_weight = torch.nn.Parameter(shuffle_weight(
-                        layer.w13_weight, layout=(32, 32)),
-                                                          requires_grad=False)
-                    layer.w2_weight = torch.nn.Parameter(shuffle_weight(
-                        layer.w2_weight, layout=(32, 32)),
-                                                         requires_grad=False)
+                if envs.VLLM_USE_AITER_2STAGE_MOE:
+                    layer.w13_weight = torch.nn.Parameter(
+                        shuffle_weight(layer.w13_weight, layout=(32, 32)),
+                        requires_grad=False)
+                    layer.w2_weight = torch.nn.Parameter(
+                        shuffle_weight(layer.w2_weight, layout=(32, 32)),
+                        requires_grad=False)
                 else:
                     layer.w13_weight = torch.nn.Parameter(shuffle_weight(
                         layer.w13_weight),
@@ -715,31 +715,32 @@ def process_weights_after_loading(self, layer: Module) -> None:
                         dq_weight, max_w13_scales[expert_id])
                     start += shard_size
 
-            if aiter_moe_enabled():
-                if aiter_2stage_moe_enabled():
+            if envs.VLLM_USE_AITER_MOE:
+                if envs.VLLM_USE_AITER_2STAGE_MOE:
                     max_w13_scales = max_w13_scales.unsqueeze(-1)
                     w2_scales = layer.w2_weight_scale.data.unsqueeze(-1)
-                    layer.w13_weight = torch.nn.Parameter(shuffle_weight(
-                        layer.w13_weight, layout=(32, 32)),
-                                                          requires_grad=False)
-                    layer.w2_weight = torch.nn.Parameter(shuffle_weight(
-                        layer.w2_weight, layout=(32, 32)),
-                                                         requires_grad=False)
                 else:
                     max_w13_scales = max_w13_scales.unsqueeze(-1).unsqueeze(
                         -1).expand((-1, layer.w13_weight.shape[1], -1))
-                    w2_scales = layer.w2_weight_scale.data.unsqueeze(
-                        -1).unsqueeze(-1).expand(
-                            (-1, layer.w2_weight.shape[1], -1))
+                    w2_scales = layer.w2_weight_scale.data.unsqueeze(-1).unsqueeze(
+                        -1).expand((-1, layer.w2_weight.shape[1], -1))
+
+                layer.w2_weight_scale = torch.nn.Parameter(
+                    w2_scales.contiguous(), requires_grad=False)
+                if envs.VLLM_USE_AITER_2STAGE_MOE:
+                    layer.w13_weight = torch.nn.Parameter(
+                        shuffle_weight(layer.w13_weight, layout=(32, 32)),
+                        requires_grad=False)
+                    layer.w2_weight = torch.nn.Parameter(
+                        shuffle_weight(layer.w2_weight, layout=(32, 32)),
+                        requires_grad=False)
+                else:
                     layer.w13_weight = torch.nn.Parameter(shuffle_weight(
                         layer.w13_weight),
                                                           requires_grad=False)
                     layer.w2_weight = torch.nn.Parameter(shuffle_weight(
                         layer.w2_weight),
                                                          requires_grad=False)
-
-            layer.w2_weight_scale = torch.nn.Parameter(
-                w2_scales.contiguous(), requires_grad=False)
             layer.w13_weight_scale = torch.nn.Parameter(
                 max_w13_scales.contiguous(), requires_grad=False)
             return
@@ -775,15 +776,15 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
         )
 
-        if aiter_moe_enabled():
-            if aiter_2stage_moe_enabled():
+        if envs.VLLM_USE_AITER_MOE:
+            if envs.VLLM_USE_AITER_2STAGE_MOE:
                 return ck_moe_2stages(a1=x,
-                                      w1=layer.w13_weight,
-                                      w2=layer.w2_weight,
-                                      topk_weight=topk_weights,
-                                      topk_ids=topk_ids,
-                                      fc1_scale=layer.w13_weight_scale,
-                                      fc2_scale=layer.w2_weight_scale)
+                                       w1=layer.w13_weight,
+                                       w2=layer.w2_weight,
+                                       topk_weight=topk_weights,
+                                       topk_ids=topk_ids,
+                                       fc1_scale=layer.w13_weight_scale,
+                                       fc2_scale=layer.w2_weight_scale)
 
         return asm_moe(
             hidden_states=x,

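Note: for FP8 MoE the apply() path picks between aiter's two-stage CK kernel and the asm_moe kernel. The asm_moe call is truncated in the hunk (only hidden_states=x is visible), so the remaining keyword arguments in this sketch are assumptions; both paths expect the weights and scales prepared in process_weights_after_loading above:

    import torch
    from vllm import envs

    def fp8_moe_apply(layer, x, topk_weights, topk_ids):
        if envs.VLLM_USE_AITER_MOE:
            if envs.VLLM_USE_AITER_2STAGE_MOE:
                from aiter.fused_moe_bf16_asm import ck_moe_2stages
                return ck_moe_2stages(a1=x,
                                      w1=layer.w13_weight,
                                      w2=layer.w2_weight,
                                      topk_weight=topk_weights,
                                      topk_ids=topk_ids,
                                      fc1_scale=layer.w13_weight_scale,
                                      fc2_scale=layer.w2_weight_scale)
            from aiter.fused_moe_bf16_asm import asm_moe
            # Only hidden_states=x is visible in the hunk; the rest is assumed.
            return asm_moe(hidden_states=x,
                           w1=layer.w13_weight,
                           w2=layer.w2_weight,
                           topk_weight=topk_weights,
                           topk_ids=topk_ids,
                           fc1_scale=layer.w13_weight_scale,
                           fc2_scale=layer.w2_weight_scale)
        ...  # non-aiter FP8 MoE path, elided here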