37 changes: 32 additions & 5 deletions csrc/sm90/flash_api.cpp
@@ -68,16 +68,19 @@ mha_fwd_kvcache_mla(
const float softmax_scale,
bool is_causal,
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
const at::Tensor &num_splits // batch_size + 1
const at::Tensor &num_splits, // batch_size + 1
c10::optional<const at::Tensor> &descale_q, // batch_size
c10::optional<const at::Tensor> &descale_k // batch_size
) {
// Check the architecture
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK(is_sm90);

// Check data types
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);
auto q_dtype = q.scalar_type();
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf ||
q_dtype == torch::kFloat8_e4m3fn, "Unsupported dtype for query tensor");
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
@@ -106,7 +109,7 @@ mha_fwd_kvcache_mla(
const int num_heads_q = sizes[2];
const int head_size_k = sizes[3];
TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported");
TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported");
TORCH_CHECK(head_size_v == 512, "Only head_size_v == 512 is supported");

const int max_num_blocks_per_seq = block_table.size(1);
const int num_blocks = kcache.size(0);
@@ -115,6 +118,20 @@
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

if (q_dtype == torch::kFloat8_e4m3fn) {
TORCH_CHECK(descale_q.has_value() && descale_k.has_value(), "descale is required when input dtype is fp8");
auto descale_q_value = descale_q.value();
auto descale_k_value = descale_k.value();
CHECK_DEVICE(descale_q_value);
CHECK_DEVICE(descale_k_value);
TORCH_CHECK(descale_q_value.stride(-1) == 1);
TORCH_CHECK(descale_k_value.stride(-1) == 1);
TORCH_CHECK(descale_q_value.dtype() == torch::kFloat);
TORCH_CHECK(descale_k_value.dtype() == torch::kFloat);
CHECK_SHAPE(descale_q_value, 1);
CHECK_SHAPE(descale_k_value, 1);
}

if (seqlen_q_ori == 1) { is_causal = false; }

const int num_q_heads_per_hk = num_heads_q / num_heads_k;
@@ -133,7 +150,8 @@
at::cuda::CUDAGuard device_guard{(char)q.get_device()};

auto opts = q.options();
at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts);
auto out_type = (q_dtype == torch::kFloat8_e4m3fn) ? torch::kBFloat16 : q_dtype; // The kernel also supports half output, but the Python API would need to expose an output-dtype option
at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts.dtype(out_type));
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
CHECK_CONTIGUOUS(softmax_lse);

@@ -152,6 +170,12 @@
params.d_v = head_size_v;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
if (q_dtype == torch::kFloat8_e4m3fn) {
// params.descale_q = get_scalar_f32_cpu_only(descale_q); // A CPU scalar would be faster, but would require changing the sglang API in use
// params.descale_k = get_scalar_f32_cpu_only(descale_k); // A CPU scalar would be faster, but would require changing the sglang API in use
params.descale_q_ptr = reinterpret_cast<float*>(descale_q.value().data_ptr());
params.descale_k_ptr = reinterpret_cast<float*>(descale_k.value().data_ptr());
}
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = kcache.data_ptr();
@@ -197,6 +221,9 @@ mha_fwd_kvcache_mla(
run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
run_flash_mla_combine_kernel<cutlass::half_t>(params, stream);
#endif
} else if (q_dtype == torch::kFloat8_e4m3fn) { // Output dtype defaults to bfloat16; the kernel can also support half
run_flash_splitkv_mla_kernel<cutlass::float_e4m3_t, cutlass::bfloat16_t>(params, stream);
run_flash_mla_combine_kernel<cutlass::bfloat16_t>(params, stream);
} else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
}
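For context, the descale checks introduced above expect single-element float32 tensors that live on the GPU and are contiguous in their last dimension. A minimal host-side sketch of a tensor that satisfies them (illustrative only; the helper name make_descale is not part of this PR):

#include <torch/torch.h>

// Builds a descale tensor that passes the new checks: shape (1,), dtype
// float32, resident on the CUDA device, contiguous so that stride(-1) == 1.
torch::Tensor make_descale(float descale_value) {
    return torch::full({1}, descale_value,
                       torch::dtype(torch::kFloat).device(torch::kCUDA));
}

A caller on the fp8 path would then pass such tensors as descale_q and descale_k alongside the float8_e4m3fn q and kcache tensors.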
83 changes: 83 additions & 0 deletions csrc/sm90/kernels/fp8_transpose_v.h
@@ -0,0 +1,83 @@
/**
* Adapted from FA3's SmemTranspose64x64:
* https://github.com/Dao-AILab/flash-attention/blob/0823cf7b5d96499c1c79a4f64b1e256a035ba4b4/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L26
*/

#pragma once
using namespace cute;

template <int kBlockN, int kHeadDim>
struct SmemTransposeFp8_64x64 {
static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0));

using Element = cutlass::float_e4m3_t;
using TransposeShapeAtomV = Shape<_64, _64>;
using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
using SmemLayoutV =
decltype(tile_to_shape(SmemLayoutAtomV{},
Shape<Int<kBlockN>, Int<kHeadDim>>{}));

// for fp8 in-kernel transpose -- src layout
using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{}, shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{})));
using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));

// For fp8, this is the memory transpose.
using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
using SmemLayoutVt =
decltype(tile_to_shape(SmemLayoutAtomVt{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));

// for fp8 in-kernel transpose -- dst layout
using SmemLayoutVtTrans = decltype(composition(
SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{})));
using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{})));
using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));


using ldsm_thread_shape = Shape<_4, _1, _8, _4>;
using ldsm_value_shape = Shape<_2, _8, _2, _1>;
using ldsm_value_stride = Stride<_2, _4, _1, _0>;
using TiledCopyLDSM = decltype(make_tiled_copy(Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, Layout<ldsm_thread_shape>{},
Layout<ldsm_value_shape, ldsm_value_stride>{}));
TiledCopyLDSM tiled_copy_ldsm;

using stsm_thread_shape = Shape<_4, _1, _8, _4>;
// using stsm_thread_stride = Stride<_1, _0, _4, _32>;
using stsm_value_shape = Shape<_4, _4, _2, _1>;
using stsm_value_stride = Stride<_1, _8, _4, _0>;

using TiledCopySTSM = decltype(make_tiled_copy(Copy_Atom<SM90_U32x4_STSM_N, Element>{}, Layout<stsm_thread_shape>{},
Layout<stsm_value_shape, stsm_value_stride>{}));
TiledCopySTSM tiled_copy_stsm;

template <class SmemTensor, class SmemTensorOut>
CUTLASS_DEVICE void transpose(SmemTensor &&s_in, SmemTensorOut &&s_out) {
using namespace cute;

auto tid = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid);
auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid);

auto tXsX = thr_copy_ldsm.partition_S(s_in);
auto tXrX = make_tensor<Element>(shape(tXsX));
auto tXsX_out = thr_copy_stsm.partition_D(s_out);

cute::copy(tiled_copy_ldsm, tXsX, tXrX);

auto data = tXrX.data();
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < size(tXrX); n += 8) {
uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
auto upper = data_32bit[0];
auto lower = data_32bit[1];
data_32bit[0] = __byte_perm(upper, lower, 0x6420);
data_32bit[1] = __byte_perm(upper, lower, 0x7531);
}

cute::copy(tiled_copy_stsm, tXrX, tXsX_out);
}
};
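The two __byte_perm selectors above, 0x6420 and 0x7531, gather the even-indexed and odd-indexed fp8 bytes of a register pair into separate 32-bit words -- the 8-bit shuffle that the 16-bit-granularity LDSM.T transpose cannot do on its own. A small host-side emulation of __byte_perm's plain byte-select behavior (illustrative only, not part of the kernel) shows the effect:

#include <cstdint>
#include <cstdio>

// Emulates __byte_perm(a, b, selector): the 8 input bytes are
// {a.byte0..a.byte3, b.byte0..b.byte3}, and each selector nibble picks one
// of them for the corresponding output byte. The msb-replication mode of the
// underlying prmt instruction is not modeled; the selectors used here
// (0x6420, 0x7531) do not rely on it.
static uint32_t byte_perm_emulated(uint32_t a, uint32_t b, uint32_t selector) {
    uint64_t pool = (static_cast<uint64_t>(b) << 32) | a;
    uint32_t out = 0;
    for (int i = 0; i < 4; ++i) {
        uint32_t sel = (selector >> (4 * i)) & 0x7;  // byte index 0..7
        out |= static_cast<uint32_t>((pool >> (8 * sel)) & 0xFF) << (8 * i);
    }
    return out;
}

int main() {
    uint32_t upper = 0x03020100;   // bytes a0=00 a1=01 a2=02 a3=03
    uint32_t lower = 0x13121110;   // bytes b0=10 b1=11 b2=12 b3=13
    printf("%08x\n", byte_perm_emulated(upper, lower, 0x6420));  // 12100200 -> a0 a2 b0 b2
    printf("%08x\n", byte_perm_emulated(upper, lower, 0x7531));  // 13110301 -> a1 a3 b1 b3
    return 0;
}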
3 changes: 3 additions & 0 deletions csrc/sm90/kernels/params.h
@@ -14,6 +14,9 @@ struct Flash_fwd_mla_params {
int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k
bool is_causal;
float scale_softmax, scale_softmax_log2;
// float descale_q, descale_k; // CPU scalars would be faster, but would require changing the sglang API in use
float* __restrict__ descale_q_ptr = nullptr;
float* __restrict__ descale_k_ptr = nullptr;

void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
6 changes: 3 additions & 3 deletions csrc/sm90/kernels/splitkv_mla.cu
@@ -1270,7 +1270,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
}


template<typename InputT>
template<typename InputT, typename OutputT = InputT>
void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream) {
using T = Traits<InputT>;
auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b);
@@ -1347,8 +1347,8 @@ void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t str
CHECK_CUDA_KERNEL_LAUNCH();
}

template void run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
template void run_flash_splitkv_mla_kernel<cutlass::bfloat16_t, cutlass::bfloat16_t>(Flash_fwd_mla_params &params, cudaStream_t stream);

#ifndef FLASH_MLA_DISABLE_FP16
template void run_flash_splitkv_mla_kernel<cutlass::half_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
template void run_flash_splitkv_mla_kernel<cutlass::half_t, cutlass::half_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
#endif
2 changes: 1 addition & 1 deletion csrc/sm90/kernels/splitkv_mla.h
@@ -2,5 +2,5 @@

#include "params.h"

template<typename InputT>
template<typename InputT, typename OutputT = InputT>
void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream);
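Because the new OutputT parameter defaults to InputT, the pre-existing bf16/fp16 instantiations and any single-argument call sites continue to compile unchanged, while the fp8 path can request a bfloat16 (or half) output type explicitly. A standalone sketch of that pattern, with illustrative names only:

#include <cstdio>
#include <type_traits>

struct Params {};  // stand-in for Flash_fwd_mla_params

// Defaulted second parameter: one-argument uses behave exactly as before,
// while two-argument uses can decouple the output type from the input type.
template <typename InputT, typename OutputT = InputT>
void run_kernel_sketch(const Params&) {
    printf("input and output types match: %s\n",
           std::is_same<InputT, OutputT>::value ? "yes" : "no");
}

int main() {
    Params p;
    run_kernel_sketch<float>(p);         // OutputT defaults to float -> "yes"
    run_kernel_sketch<char, float>(p);   // distinct output type      -> "no"
    return 0;
}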