@@ -805,14 +805,18 @@ def forward(
         """Forward pass with FlashAttention.
 
         Args:
-            query: shape = [num_tokens, num_heads, head_size]
-            key: shape = [num_tokens, num_kv_heads, head_size]
-            value: shape = [num_tokens, num_kv_heads, head_size]
-            output: shape = [num_tokens, num_heads, head_size]
-            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
+            layer: Attention layer instance.
+            q: Query tensor with shape = [num_tokens, num_heads, head_size]
+            k: Key tensor with shape = [num_tokens, num_kv_heads, head_size]
+            v: Value tensor with shape = [num_tokens, num_kv_heads, head_size]
+            kv_cache: KV cache tensor with shape
+                [2, num_blocks, block_size, num_kv_heads, head_size].
                 NOTE: kv_cache will be an empty tensor with shape [0]
                 for profiling run.
             attn_metadata: Metadata for attention.
+            output: Output tensor with shape [num_tokens, num_heads, head_size]
+            output_scale: Optional output scale tensor.
+            output_block_scale: Optional output block scale tensor.
         NOTE: It in-place updates the output tensor.
         NOTE: FP8 quantization, flash-attn expect the size of
         {q,k,v}_descale to be (num_sequences, num_kv_heads).
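
For context, here is a minimal sketch of tensors matching the shapes the updated docstring describes. The sizes are arbitrary, and `impl`, `layer`, and `attn_metadata` are placeholders for objects vLLM constructs internally; the commented-out call is indicative only, not part of the PR.

```python
import torch

# Hypothetical sizes chosen only to make the shapes concrete.
num_tokens, num_heads, num_kv_heads, head_size = 32, 8, 2, 64
num_blocks, block_size = 128, 16

q = torch.randn(num_tokens, num_heads, head_size)        # query
k = torch.randn(num_tokens, num_kv_heads, head_size)     # key
v = torch.randn(num_tokens, num_kv_heads, head_size)     # value
output = torch.empty(num_tokens, num_heads, head_size)   # updated in place
kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)
# For the profiling run the cache is passed as an empty tensor instead:
# kv_cache = torch.empty(0)

# Indicative call only; `impl`, `layer`, and `attn_metadata` come from
# vLLM's own setup and are not shown here:
# impl.forward(layer, q, k, v, kv_cache, attn_metadata, output=output)
```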