65 changes: 63 additions & 2 deletions vllm/_aiter_ops.py
@@ -6,7 +6,6 @@
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def rocm_aiter_tuned_gemm_impl(
input: torch.Tensor,
weight: torch.Tensor,
@@ -45,6 +44,49 @@ def rocm_aiter_tuned_gemm_fake(
return torch.empty((m, n), dtype=out_dtype, device=input.device)


def rocm_aiter_rmsnorm2d_fwd_with_add_quant_impl(
        input: torch.Tensor,
        residual: Optional[torch.Tensor],
        weight: torch.Tensor,
        variance_epsilon: float,
        x_scale: Optional[torch.Tensor] = None,
        y_scale_dtype: Optional[torch.dtype] = None,
        q_dtype: torch.dtype = torch.float8_e4m3fnuz,
        model_sensitive: float = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # TODO: make q_dtype general
    import aiter as rocm_aiter
    assert y_scale_dtype is not None  # TODO: y_scale_dtype is required for now

    output = torch.empty(input.shape, dtype=q_dtype, device=input.device)
    # TODO: only per-token quantization is supported for now
    y_scale = torch.empty(input.shape[0], 1, dtype=y_scale_dtype,
                          device=input.device)

    if x_scale is None:
        if residual is None:
            residual_out = None
            rocm_aiter.rmsnorm2d_fwd_with_dynamicquant(
                output, input, y_scale, weight, variance_epsilon,
                model_sensitive)
        else:
            residual_out = torch.empty_like(input)
            rocm_aiter.rmsnorm2d_fwd_with_add_dynamicquant(
                output, input, residual, residual_out, y_scale, weight,
                variance_epsilon, model_sensitive)
    else:
        if residual is None:
            residual_out = None
            rocm_aiter.rmsnorm2d_fwd_with_smoothquant(
                output, input, x_scale, y_scale, weight, variance_epsilon,
                model_sensitive)
        else:
            residual_out = torch.empty_like(input)
            rocm_aiter.rmsnorm2d_fwd_with_add_smoothquant(
                output, input, residual, residual_out, x_scale, y_scale,
                weight, variance_epsilon)
    return output, residual_out, y_scale

def rocm_aiter_rmsnorm2d_fwd_with_add_quant_fake(
        input: torch.Tensor, residual: Optional[torch.Tensor],
        weight: torch.Tensor, variance_epsilon: float,
        x_scale: Optional[torch.Tensor] = None,
        y_scale_dtype: Optional[torch.dtype] = None,
        q_dtype: torch.dtype = torch.float8_e4m3fnuz, model_sensitive: float = 0
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    return (torch.empty(input.shape, dtype=q_dtype, device=input.device),
            torch.empty_like(input),
            torch.empty(input.shape[0], 1, dtype=y_scale_dtype,
                        device=input.device))


if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_tuned_gemm",
@@ -54,6 +96,13 @@ def rocm_aiter_tuned_gemm_fake(
dispatch_key=current_platform.dispatch_key,
)

direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add_quant",
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_quant_impl,
mutates_args=[],
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_quant_fake,
dispatch_key=current_platform.dispatch_key,
)

class aiter_ops:

@@ -73,4 +122,16 @@ def rocm_aiter_tuned_gemm(
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
)
)

    @staticmethod
    def rocm_aiter_rmsnorm2d_fwd_with_add_quant(
            input: torch.Tensor, residual: Optional[torch.Tensor],
            weight: torch.Tensor, variance_epsilon: float,
            x_scale: Optional[torch.Tensor] = None,
            y_scale_dtype: Optional[torch.dtype] = None,
            q_dtype: torch.dtype = torch.float8_e4m3fnuz,
            model_sensitive: float = 0
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add_quant(
            input, residual, weight, variance_epsilon, x_scale,
            y_scale_dtype, q_dtype, model_sensitive)

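A minimal usage sketch of the new wrapper (not part of the diff; it assumes a ROCm build of vLLM with AITER installed, and the tensor shapes and epsilon are illustrative):

# Hedged sketch: requires ROCm + AITER; shapes and eps are made up for illustration.
import torch
from vllm._aiter_ops import aiter_ops

hidden = torch.randn(8, 7168, dtype=torch.bfloat16, device="cuda")
gamma = torch.ones(7168, dtype=torch.bfloat16, device="cuda")

# One kernel fuses RMSNorm (optionally with a residual add) and per-token FP8
# quantization of the normalized output.
out, residual_out, y_scale = aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(
    hidden, residual=None, weight=gamma, variance_epsilon=1e-6,
    x_scale=None, y_scale_dtype=torch.float32)

# out: FP8 activations (q_dtype); y_scale: (num_tokens, 1) per-token scales that
# downstream FP8 linears can reuse; residual_out is None since no residual was passed.
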
6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -79,6 +79,7 @@
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_ASMMOE: bool = False
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_ROPE: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
@@ -589,6 +590,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"VLLM_ROCM_CUSTOM_PAGED_ATTN":
lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
("true", "1")),

    # Whether to use the fused RMSNorm + FP8 quantization kernel from AITER
    "VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT":
    lambda: (os.getenv("VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT", "True").lower() in
             ("true", "1")),

# Divisor for dynamic query scale factor calculation for FP8 KV Cache
"Q_SCALE_CONSTANT":
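The flag only takes effect together with the existing AITER switches (see the gating helper added in deepseek_v2.py below). A small sketch of disabling it for an A/B comparison, relying on the os.getenv-based parsing shown above:

# Sketch only: any value other than "true"/"1" disables the fused path.
import os
os.environ["VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT"] = "0"

import vllm.envs as envs
assert envs.VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT is False
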
12 changes: 8 additions & 4 deletions vllm/model_executor/layers/linear.py
@@ -381,11 +381,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
param.data.copy_(loaded_weight)

def forward(
self, x: torch.Tensor
self, x: torch.Tensor, input_scale: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
if input_scale is not None:
output = self.quant_method.apply(self, x, bias, input_scale)
else:
output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
@@ -532,13 +535,14 @@ def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
param.load_column_parallel_weight(loaded_weight=loaded_weight)

def forward(
self, input_
        self, input_, input_scale: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None

# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
        # input_scale is only set by the fused RMSNorm + quant path; other
        # callers keep the original three-argument apply() call.
        if input_scale is not None:
            output_parallel = self.quant_method.apply(self, input_, bias,
                                                      input_scale)
        else:
            output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
Original file line number Diff line number Diff line change
@@ -566,7 +566,8 @@ def create_weights(self, layer: torch.nn.Module,
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None):
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None):
"""
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
@@ -577,7 +578,7 @@ def apply(self,
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias)
return scheme.apply_weights(layer, x, bias=bias, input_scale=input_scale)


class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@ def create_weights(self, *args, **kwargs):

@abstractmethod
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]):
bias: Optional[torch.Tensor], input_scale: Optional[torch.Tensor]):
"""
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
@@ -41,6 +41,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
other parameters relevant to the particular scheme.
:param x: input to the layer
:param bias: bias parameter
        :param input_scale: input scale from the fused RMSNorm + quant kernel

"""
raise NotImplementedError
Original file line number Diff line number Diff line change
@@ -156,11 +156,14 @@ def create_weights(self, layer: torch.nn.Module,
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
bias: Optional[torch.Tensor] = None,
                      input_scale: Optional[torch.Tensor] = None) -> torch.Tensor:

        effective_input_scale = (input_scale if input_scale is not None
                                 else layer.input_scale)

return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
input_scale=effective_input_scale,
bias=bias)
18 changes: 15 additions & 3 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -396,7 +396,19 @@ def process_weights_after_loading(self, layer: Module) -> None:
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Apply the FP8 linear transformation.

Args:
layer: The linear layer module
x: Input tensor
bias: Optional bias tensor
input_scale: Optional input scale tensor from fused_rmsnorm_quant
"""
# Use provided input_scale if available, otherwise use layer's input_scale
        effective_input_scale = (input_scale if input_scale is not None
                                 else layer.input_scale)

if self.use_marlin:
return apply_fp8_marlin_linear(
@@ -415,7 +427,7 @@ def apply(self,
weight=layer.weight,
block_size=self.quant_config.weight_block_size,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
input_scale=effective_input_scale,
bias=bias,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
@@ -425,7 +437,7 @@ def apply(self,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
input_scale=effective_input_scale,
bias=bias)


4 changes: 3 additions & 1 deletion vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -387,7 +387,6 @@ def apply(
# so fallback to naive if per channel or per token
else:
if input.dtype != current_platform.fp8_dtype():

if not self.use_aiter_and_is_supported:
# Maybe apply padding to output, see comment in __init__
qinput, x_scale = ops.scaled_fp8_quant(
@@ -407,6 +406,9 @@
else:
qinput, x_scale = input_2d, input_scale

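        # The fused RMSNorm + quant path supplies its own input_scale; reuse the
        # incoming activations and that scale instead of re-quantizing them here.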
if envs.VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT and input_scale is not None:
qinput, x_scale = input_2d, input_scale

per_tensor_weights = (weight_scale.numel()
== 1) and weight_scale.dim() < 2
per_tensor_activations = (x_scale.numel()
84 changes: 67 additions & 17 deletions vllm/model_executor/models/deepseek_v2.py
@@ -57,6 +57,15 @@
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

import vllm.envs as envs
from vllm.platforms import current_platform
from vllm._aiter_ops import aiter_ops

def is_rocm_aiter_fused_rms_quant_enabled() -> bool:
    return (current_platform.is_rocm()
            and envs.VLLM_ROCM_USE_AITER_RMSNORM
            and envs.VLLM_ROCM_USE_AITER
            and envs.VLLM_ROCM_USE_AITER_FUSED_RMSNORM_QUANT)

class DeepseekV2MLP(nn.Module):

@@ -291,12 +300,20 @@ def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None,
        quant_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
self.qk_head_dim)
            if is_rocm_aiter_fused_rms_quant_enabled():
                # Fused RMSNorm + FP8 quant; the residual output is unused here.
                q, _, q_scale = aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(
                    q, residual=None, weight=self.q_a_layernorm.weight,
                    variance_epsilon=self.q_a_layernorm.variance_epsilon,
                    x_scale=None, y_scale_dtype=torch.float32)
                q = self.q_b_proj(q, q_scale)[0].view(-1, self.num_local_heads,
                                                      self.qk_head_dim)
            else:
                q = self.q_a_layernorm(q)
                q = self.q_b_proj(q)[0].view(-1, self.num_local_heads,
                                             self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
self.qk_head_dim)
@@ -306,8 +323,14 @@ def forward(
kv_a, _ = latent_cache.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a.contiguous())
kv = self.kv_b_proj(kv_a)[0]
        if is_rocm_aiter_fused_rms_quant_enabled():
            # Fused RMSNorm + FP8 quant; the residual output is unused here.
            kv_a, _, kv_a_y_scale = (
                aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(
                    kv_a, residual=None, weight=self.kv_a_layernorm.weight,
                    variance_epsilon=self.kv_a_layernorm.variance_epsilon,
                    x_scale=None, y_scale_dtype=torch.float32))
            kv = self.kv_b_proj(kv_a, kv_a_y_scale)[0]
        else:
            kv_a = self.kv_a_layernorm(kv_a.contiguous())
            kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads,
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
@@ -467,14 +490,19 @@ def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
quant_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0]
if is_rocm_aiter_fused_rms_quant_enabled():
ckq = self.q_a_proj(quant_hidden_states, input_scale)[0]
else:
ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq)
else:
hidden_states_or_q_c = hidden_states
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
return self.mla_attn(hidden_states_or_q_c,
kv_c_normed,
@@ -553,16 +581,38 @@ def forward(
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
        if is_rocm_aiter_fused_rms_quant_enabled():
            if residual is None:
                q_hidden_states, _, y_scale = (
                    aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(
                        hidden_states, residual=None,
                        weight=self.input_layernorm.weight,
                        variance_epsilon=self.input_layernorm.variance_epsilon,
                        x_scale=None, y_scale_dtype=torch.float32))
                residual = hidden_states
            else:  # temporary workaround to keep CUDA graph capture working
                q_hidden_states, residual_out, y_scale = (
                    aiter_ops.rocm_aiter_rmsnorm2d_fwd_with_add_quant(
                        hidden_states, residual=residual,
                        weight=self.input_layernorm.weight,
                        variance_epsilon=self.input_layernorm.variance_epsilon,
                        x_scale=None, y_scale_dtype=torch.float32))
                residual = residual_out
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(
                    hidden_states, residual)
        if is_rocm_aiter_fused_rms_quant_enabled():
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                input_scale=y_scale,
                quant_hidden_states=q_hidden_states,
            )
        else:
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
            )

if hidden_states.dtype == torch.float16:
# Fix FP16 overflow