diff --git a/csrc/sm90/flash_api.cpp b/csrc/sm90/flash_api.cpp index a87e1ab..6e6b69b 100644 --- a/csrc/sm90/flash_api.cpp +++ b/csrc/sm90/flash_api.cpp @@ -68,7 +68,9 @@ mha_fwd_kvcache_mla( const float softmax_scale, bool is_causal, const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize - const at::Tensor &num_splits // batch_size + 1 + const at::Tensor &num_splits, // batch_size + 1 + c10::optional &descale_q, // batch_size + c10::optional &descale_k // batch_size ) { // Check the architecture auto dprops = at::cuda::getCurrentDeviceProperties(); @@ -76,8 +78,9 @@ mha_fwd_kvcache_mla( TORCH_CHECK(is_sm90); // Check data types - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf); + auto q_dtype = q.scalar_type(); + TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf|| + q_dtype == torch::kFloat8_e4m3fn, "Unsupported dtype for query tensor"); TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); @@ -106,7 +109,7 @@ mha_fwd_kvcache_mla( const int num_heads_q = sizes[2]; const int head_size_k = sizes[3]; TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported"); - TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported"); + TORCH_CHECK(head_size_v == 512, "Only head_size_v == 512 is supported"); const int max_num_blocks_per_seq = block_table.size(1); const int num_blocks = kcache.size(0); @@ -115,6 +118,20 @@ mha_fwd_kvcache_mla( TORCH_CHECK(batch_size > 0, "batch size must be postive"); TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (q_dtype == torch::kFloat8_e4m3fn) { + TORCH_CHECK(descale_q.has_value() && descale_k.has_value(), "descale is required when input dtype is fp8"); + auto descale_q_value = descale_q.value(); + auto descale_k_value = descale_k.value(); + CHECK_DEVICE(descale_q_value); + CHECK_DEVICE(descale_k_value); + TORCH_CHECK(descale_q_value.stride(-1) == 1); + TORCH_CHECK(descale_k_value.stride(-1) == 1); + TORCH_CHECK(descale_q_value.dtype() == torch::kFloat); + TORCH_CHECK(descale_k_value.dtype() == torch::kFloat); + CHECK_SHAPE(descale_q_value, 1); + CHECK_SHAPE(descale_k_value, 1); + } + if (seqlen_q_ori == 1) { is_causal = false; } const int num_q_heads_per_hk = num_heads_q / num_heads_k; @@ -133,7 +150,8 @@ mha_fwd_kvcache_mla( at::cuda::CUDAGuard device_guard{(char)q.get_device()}; auto opts = q.options(); - at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts); + auto out_type = (q_dtype == torch::kFloat8_e4m3fn) ? 
torch::kBFloat16 : q_dtype; // Kernel already supports half, but need change python api for output dtype + at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts.dtype(out_type)); at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); CHECK_CONTIGUOUS(softmax_lse); @@ -152,6 +170,12 @@ mha_fwd_kvcache_mla( params.d_v = head_size_v; params.scale_softmax = softmax_scale; params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); + if (q_dtype == torch::kFloat8_e4m3fn) { + // params.descale_q = get_scalar_f32_cpu_only(descale_q); // cpu scalar faster ,but need change sglang api used + // params.descale_k = get_scalar_f32_cpu_only(descale_q); // cpu scalar faster ,but need change sglang api used + params.descale_q_ptr = reinterpret_cast(descale_q.value().data_ptr()); + params.descale_k_ptr = reinterpret_cast(descale_k.value().data_ptr()); + } // Set the pointers and strides. params.q_ptr = q.data_ptr(); params.k_ptr = kcache.data_ptr(); @@ -197,6 +221,9 @@ mha_fwd_kvcache_mla( run_flash_splitkv_mla_kernel(params, stream); run_flash_mla_combine_kernel(params, stream); #endif + } else if (q_dtype == torch::kFloat8_e4m3fn) { // Output default dtype is bfloat16_t, can support half. + run_flash_splitkv_mla_kernel(params, stream); + run_flash_mla_combine_kernel(params, stream); } else { TORCH_CHECK(false, "Unsupported tensor dtype for query"); } diff --git a/csrc/sm90/kernels/fp8_transpose_v.h b/csrc/sm90/kernels/fp8_transpose_v.h new file mode 100644 index 0000000..5a7e6ff --- /dev/null +++ b/csrc/sm90/kernels/fp8_transpose_v.h @@ -0,0 +1,83 @@ +/** + * ref to Fa3's SmemTranspose64x64: + * https://github.com/Dao-AILab/flash-attention/blob/0823cf7b5d96499c1c79a4f64b1e256a035ba4b4/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L26 +*/ + +#pragma once +using namespace cute; + +template +struct SmemTransposeFp8_64x64 { + static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0)); + + using Element = cutlass::float_e4m3_t; + using TransposeShapeAtomV = Shape<_64, _64>; + using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtomV{})); + using SmemLayoutV = + decltype(tile_to_shape(SmemLayoutAtomV{}, + Shape, Int>{})); + + // for fp8 in-kernel transpose -- src layout + using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{})); + using SmemShapeLDSM = Shape, Shape<_16, _4>>; + using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{}, shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}))); + using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{}))); + + // For fp8, this is the memory transpose. 
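  // Illustrative sketch (for exposition only, not used by the kernel): ignoring the
  // swizzled smem layouts and the LDSM/STSM plumbing below, the net effect of this
  // class on one 64x64 fp8 tile is a plain element transpose, whose result lands in
  // the sVt buffers consumed by the P*V GEMMs. A naive reference version, with
  // hypothetical names, would be:
  static void reference_transpose_64x64(const uint8_t* v_in, uint8_t* vt_out) {
    // v_in / vt_out: row-major 64x64 byte tiles standing in for float_e4m3_t
    for (int i = 0; i < 64; ++i)
      for (int j = 0; j < 64; ++j)
        vt_out[j * 64 + i] = v_in[i * 64 + j];
  }
  // The in-kernel path below computes the same result with ldmatrix/stmatrix plus
  // __byte_perm byte shuffles, so the transpose stays inside shared memory.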
+ using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtomV{})); + using SmemLayoutVt = + decltype(tile_to_shape(SmemLayoutAtomVt{}, + Shape, Int>{})); + + // for fp8 in-kernel transpose -- dst layout + using SmemLayoutVtTrans = decltype(composition( + SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{}))); + using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{})); + using SmemShapeSTSM = Shape, Shape<_16, _4>>; + using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{}))); + using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{}))); + + + using ldsm_thread_shape = Shape<_4, _1, _8, _4>; + using ldsm_value_shape = Shape<_2, _8, _2, _1>; + using ldsm_value_stride = Stride<_2, _4, _1, _0>; + using TiledCopyLDSM = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, + Layout{})); + TiledCopyLDSM tiled_copy_ldsm; + + using stsm_thread_shape = Shape<_4, _1, _8, _4>; + // using stsm_thread_stride = Stride<_1, _0, _4, _32>; + using stsm_value_shape = Shape<_4, _4, _2, _1>; + using stsm_value_stride = Stride<_1, _8, _4, _0>; + + using TiledCopySTSM = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, + Layout{})); + TiledCopySTSM tiled_copy_stsm; + + template + CUTLASS_DEVICE void transpose(SmemTensor &&s_in, SmemTensorOut &&s_out) { + using namespace cute; + + auto tid = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid); + auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid); + + auto tXsX = thr_copy_ldsm.partition_S(s_in); + auto tXrX = make_tensor(shape(tXsX)); + auto tXsX_out = thr_copy_stsm.partition_D(s_out); + + cute::copy(tiled_copy_ldsm, tXsX, tXrX); + + auto data = tXrX.data(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size(tXrX); n += 8) { + uint32_t *data_32bit = reinterpret_cast(&data[n]); + auto upper = data_32bit[0]; + auto lower = data_32bit[1]; + data_32bit[0] = __byte_perm(upper, lower, 0x6420); + data_32bit[1] = __byte_perm(upper, lower, 0x7531); + } + + cute::copy(tiled_copy_stsm, tXrX, tXsX_out); + } +}; diff --git a/csrc/sm90/kernels/params.h b/csrc/sm90/kernels/params.h index 3b4e254..71d2fce 100644 --- a/csrc/sm90/kernels/params.h +++ b/csrc/sm90/kernels/params.h @@ -14,6 +14,9 @@ struct Flash_fwd_mla_params { int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k bool is_causal; float scale_softmax, scale_softmax_log2; + // float descale_q, descale_k; // cpu scalar faster ,but need change sglang api used + float* __restrict__ descale_q_ptr = nullptr; + float* __restrict__ descale_k_ptr = nullptr; void *__restrict__ q_ptr; void *__restrict__ k_ptr; diff --git a/csrc/sm90/kernels/splitkv_mla.cu b/csrc/sm90/kernels/splitkv_mla.cu index 5e1fded..4a79dd4 100644 --- a/csrc/sm90/kernels/splitkv_mla.cu +++ b/csrc/sm90/kernels/splitkv_mla.cu @@ -1270,7 +1270,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params } -template +template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { using T = Traits; auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b); @@ -1347,8 +1347,8 @@ void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t str CHECK_CUDA_KERNEL_LAUNCH(); } -template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); +template void 
run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); #ifndef FLASH_MLA_DISABLE_FP16 -template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); +template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); #endif diff --git a/csrc/sm90/kernels/splitkv_mla.h b/csrc/sm90/kernels/splitkv_mla.h index 479fb50..c0d830e 100644 --- a/csrc/sm90/kernels/splitkv_mla.h +++ b/csrc/sm90/kernels/splitkv_mla.h @@ -2,5 +2,5 @@ #include "params.h" -template +template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/csrc/sm90/kernels/splitkv_mla_fp8.cu b/csrc/sm90/kernels/splitkv_mla_fp8.cu new file mode 100644 index 0000000..30079d1 --- /dev/null +++ b/csrc/sm90/kernels/splitkv_mla_fp8.cu @@ -0,0 +1,1338 @@ +#include + +#include "params.h" +#include "utils.h" +#include "config.h" +#include "traits.h" + +using namespace cute; +using cutlass::arch::NamedBarrier; + +// Here we use MAX_INIT_VAL_SM to initialize sM, and MAX_INIT_VAL for masking +// The reason is that, we need to calculate new_max = max(sM(row_idx), cur_max*scale_softmax_log2) +// so we must guarantee that MAX_INIT_VAL*scale_softmax_log2 < MAX_INIT_VAL_SM +static constexpr float MAX_INIT_VAL_SM = -1e30f; +static constexpr float MAX_INIT_VAL = -1e33f; + + +CUTLASS_DEVICE int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { + // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx + // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a + int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4); + return row_idx; +} + +// Launch TMA copy for a range of KV tile +// A tile has a shape of PAGE_BLOCK_SIZE (64) x 64 +template< + int START_HEAD_DIM_TILE_IDX, + int END_HEAD_DIM_TILE_IDX, + typename TMA_K_OneTile, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1 +> +CUTLASS_DEVICE void launch_kv_tiles_copy_tma( + Tensor const &gKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K) + Tensor &sKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K), swizzled + TMA_K_OneTile &tma_K, + TMABarrier* barriers_K, + int idx_in_warpgroup +) { + if (idx_in_warpgroup == 0) { + auto thr_tma = tma_K.get_slice(_0{}); + Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int{}); + Tensor cur_sKV = thr_tma.partition_D(sKV)(_, _0{}, Int{}); + cute::copy(tma_K.with(reinterpret_cast(barriers_K[START_HEAD_DIM_TILE_IDX]), 0, cute::TMA::CacheHintSm90::EVICT_FIRST), cur_gKV, cur_sKV); + if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) { + launch_kv_tiles_copy_tma(gKV, sKV, tma_K, barriers_K, idx_in_warpgroup); + } + } +} + +// Prefetch some KV tiles +// Currently this is not used because it leads to performance degradation +template< + int START_HEAD_DIM_TILE_IDX, + int END_HEAD_DIM_TILE_IDX, + typename TMA_K_OneTile, + typename Engine0, typename Layout0 +> +CUTLASS_DEVICE void prefetch_kv_tiles( + Tensor const &gKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K) + TMA_K_OneTile &tma_K, + int idx_in_warpgroup +) { + if (idx_in_warpgroup == 0) { + auto thr_tma = tma_K.get_slice(_0{}); + Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int{}); + cute::prefetch(tma_K, cur_gKV); + if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) { + prefetch_kv_tiles(gKV, tma_K, idx_in_warpgroup); + } + } +} + +// Adapted from 
https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h +// * Copyright (c) 2024, Tri Dao. +template +CUTLASS_DEVICE void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } + warpgroup_fence_operand(tCrC); + if constexpr (arrive) { + warpgroup_arrive(); + } + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + if constexpr (commit) { + warpgroup_commit_batch(); + } + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } +} + + +// Wait for one KV-tile to be ready, and then calculate P += Q K^T for one Q-tile (BLOCK_SIZE_Mx64) and one KV-tile (PAGE_BLOCK_SIZEx64) +// The Q-tile should be in shared memory +template< + typename T, + typename TiledMMA, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +CUTLASS_DEVICE void qkt_gemm_one_tile_sQ( + TiledMMA &tiled_mma, + Tensor const &thr_mma_sQ_tile, // (MMA, 1, 2) + Tensor const &thr_mma_sKV_tile, // (MMA, 1, 2) + Tensor &rP, // ((4, 2, 2), 1, 2) + typename T::InputT* sK_ptr, + typename T::InputT* sVt_ptr, + TMABarrier* barrier, + bool &cur_phase, + int idx_in_warpgroup, + int tile_idx, + int v_named_barrier, + int valid_window_size +) { + if (idx_in_warpgroup == 0) { + barrier->arrive_and_expect_tx(64*64*sizeof(typename T::InputT)); + } + barrier->wait(cur_phase ? 1 : 0); + if(tile_idx != 8){ // V: (0~7) tiles + if(valid_window_size>0 && valid_window_size(sK_ptr + tile_idx*64*64, valid_window_size, idx_in_warpgroup); + cutlass::arch::fence_view_async_shared(); + NamedBarrier::arrive_and_wait(128, v_named_barrier); + } + fp8_transpose_v(sK_ptr, sVt_ptr, tile_idx); // transpose sK + cutlass::arch::fence_view_async_shared(); + } + + warpgroup_fence_operand(rP); + warpgroup_arrive(); + cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP); + warpgroup_commit_batch(); + warpgroup_fence_operand(rP); +} + + +// Pipelined TMA wait and Q K^T gemm +// In order to overlap memory copy (G->S copy for K) and computation, we divide both Q and K into tiles of shape (BLOCK_SIZE_M, 64), and (PAGE_BLOCK_SIZE, 64) respectively, and then do the computation as follows: +// - Wait for the 0-th tile to be ready using `barrier.wait()` +// - Compute Q K^T for the 0-th tile +// - Wait for the 1-st tile to be ready +// - Compute Q K^T for the 1-st tile +// ... 
+// This gives latter tiles more time to be ready, and thus can overlap the memory copy and computation +template< + typename T, // TraitsFP8 + int PHASE_IDX, // See comments in the code + int pipeline_id, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +CUTLASS_DEVICE void warpgroup_cooperative_qkt_gemm( + Tensor &sQ, // (BLOCK_SIZE_M, HEAD_DIM_K) + Tensor &sKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K) + Tensor &rP, // ((4, 2, 2), 1, 2) + typename T::InputT* sVt_ptr, + TMABarrier* barriers, + bool &cur_phase, + int idx_in_warpgroup, + int v_named_barrier, + int valid_window_size +) { + Tensor sQ_tiled = flat_divide(sQ, Shape, _64>{})(_, _, _0{}, _); // (BLOCK_SIZE_M, 64, 9) + Tensor sKV_tiled = flat_divide(sKV, Shape, _64>{})(_, _, _0{}, _); // (PAGE_BLOCK_SIZE, 64, 9) + TiledMMA tiled_mma_sQ = (typename T::TiledMMA_QK_sQ){}; + ThrMMA thr_mma_sQ = tiled_mma_sQ.get_slice(idx_in_warpgroup); + Tensor thr_mma_sQ_tiled = thr_mma_sQ.partition_fragment_A(sQ_tiled); // (MMA, 1, 2, 9) + Tensor thr_mma_sKV_tiled = thr_mma_sQ.partition_fragment_B(sKV_tiled); // (MMA, 1, 2, 9) + + #define QKT_GEMM_ONE_TILE(TILE_IDX) \ + qkt_gemm_one_tile_sQ(tiled_mma_sQ, thr_mma_sQ_tiled(_, _, _, Int{}), thr_mma_sKV_tiled(_, _, _, Int{}), rP, sKV.data().get().get(), sVt_ptr, barriers + TILE_IDX, cur_phase, idx_in_warpgroup, TILE_IDX, v_named_barrier, valid_window_size); + if constexpr (PHASE_IDX == 0) { + // In PHASE-0, warpgroup 0 calculates Q K^T for the first 4 tiles + tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero; + QKT_GEMM_ONE_TILE(0); + QKT_GEMM_ONE_TILE(1); + QKT_GEMM_ONE_TILE(2); + QKT_GEMM_ONE_TILE(3); + } else if constexpr (PHASE_IDX == 1) { + // In PHASE-1, warpgroup 1 calculates Q K^T for all the 9 tiles + tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero; + QKT_GEMM_ONE_TILE(4); + QKT_GEMM_ONE_TILE(5); + QKT_GEMM_ONE_TILE(6); + QKT_GEMM_ONE_TILE(7); + QKT_GEMM_ONE_TILE(8); + QKT_GEMM_ONE_TILE(0); + QKT_GEMM_ONE_TILE(1); + QKT_GEMM_ONE_TILE(2); + QKT_GEMM_ONE_TILE(3); + cur_phase ^= 1; + } else { + // In PHASE-2, warpgroup 0 calculates Q K^T for the last 5 tiles + static_assert(PHASE_IDX == 2); + tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::One; + QKT_GEMM_ONE_TILE(4); + QKT_GEMM_ONE_TILE(5); + QKT_GEMM_ONE_TILE(6); + QKT_GEMM_ONE_TILE(7); + QKT_GEMM_ONE_TILE(8); + cur_phase ^= 1; + } +} + +template< + typename T, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +CUTLASS_DEVICE void warpgroup_cooperative_qkt_gemm_no_pipeline( + Tensor &sQ, // (BLOCK_SIZE_M, HEAD_DIM_K) + Tensor &sKV, // (BLOCK_SIZE_M, HEAD_DIM_K) + Tensor &rP, // ((4, 2, 2), 1, 2) + int idx_in_warpgroup +) { + TiledMMA tiled_mma = (typename T::TiledMMA_QK_sQ){}; + ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); + Tensor thr_mma_sQ = thr_mma.partition_fragment_A(sQ); // (MMA, 1, 576/32=18) + Tensor thr_mma_sKV = thr_mma.partition_fragment_B(sKV); // (MMA, 1, 576/32=18) + gemm(tiled_mma, thr_mma_sQ, thr_mma_sKV, rP); +} + + +// Compute O += PV, where P resides in register +template< + typename T, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +CUTLASS_DEVICE void warpgroup_cooperative_pv_gemm_localP( + Tensor &rP, // ((4, 2, 2), 1, 2), fragment A layout + Tensor &sKV_half, // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE) + Tensor &rO, // ((2, 2, 32), 1, 1) + int idx_in_warpgroup +) { + TiledMMA tiled_mma = (typename T::TiledMMA_PV_LocalP){}; + ThrMMA 
thr_mma = tiled_mma.get_slice(idx_in_warpgroup); + Tensor rP_retiled = make_tensor(rP.data(), Layout< + Shape, _1, _2>, + Stride, _0, _16> + >{}); + Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half); // (MMA, 1, 64/32=2) + gemm(tiled_mma, rP_retiled, thr_mma_sKV_half, rO); +} + +// Compute O += PV, where P resides in shared memory +template< + typename T, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +CUTLASS_DEVICE void warpgroup_cooperative_pv_gemm_remoteP( + Tensor &sP, + Tensor &sKV_half, // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE) + Tensor &rO, // ((2, 2, 32), 1, 1) + int idx_in_warpgroup +) { + TiledMMA tiled_mma = (typename T::TiledMMA_PV_RemoteP){}; + ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); + Tensor thr_mma_sP = thr_mma.partition_fragment_A(sP); + Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half); // (MMA, 1, 64/32=2) + + gemm(tiled_mma, thr_mma_sP, thr_mma_sKV_half, rO); +} + + +template< + typename T, + bool DO_OOB_FILLING, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2, + typename Engine3, typename Layout3, + typename Engine4, typename Layout4 +> +CUTLASS_DEVICE void wg0_bunch_0( + Tensor &rP0, // ((4, 2, 2), 1, 2) + Tensor &rO0, // ((2, 2, 32), 1, 1) + Tensor &sScale0, // (BLOCK_SIZE_M) + Tensor &sM, // (BLOCK_SIZE_M) + float rL[2], + int rRightBorderForQSeq[2], + float scale_softmax_log2, + int start_token_idx, + int idx_in_warpgroup +) { + // This piece of code is tightly coupled [Accumulate's layout](https://docs.nvidia.com/cuda/parallel-thread-execution/_images/wgmma-64N16-D.png) + CUTLASS_PRAGMA_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); + + // Mask, and get row-wise max + float cur_max = MAX_INIT_VAL; + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) { + if constexpr (DO_OOB_FILLING) { + int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2; + rP0(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP0(i) : MAX_INIT_VAL; + rP0(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? rP0(i+1) : MAX_INIT_VAL; + } + cur_max = max(cur_max, max(rP0(i), rP0(i+1))); + } + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); + + // Update sM and sL + cur_max *= scale_softmax_log2; + float new_max = max(sM(row_idx), cur_max); + float scale_for_old = exp2f(sM(row_idx) - new_max); + __syncwarp(); // Make sure all reads have finished before updating sM + if (idx_in_warpgroup%4 == 0) { + sScale0(row_idx) = scale_for_old; + sM(row_idx) = new_max; + } + + // Scale-O + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rO0); i += 4) { + rO0(i) *= scale_for_old; + rO0(i+1) *= scale_for_old; + } + + // Scale, exp, and get row-wise expsum + float cur_sum = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 
2 : 0; i < size(rP0); i += 4) { + rP0(i) = exp2f(rP0(i)*scale_softmax_log2 - new_max); + rP0(i+1) = exp2f(rP0(i+1)*scale_softmax_log2 - new_max); + //rPb(i) = (typename T::InputT)rP0(i); + //rPb(i+1) = (typename T::InputT)rP0(i+1); + cur_sum += rP0(i) + rP0(i+1); + } + rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum; + } +} + + +template< + typename T, + bool IS_BLK0_LAST, + bool IS_BLK1_LAST, + bool IS_BLK2_LAST, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2, + typename Engine3, typename Layout3, + typename Engine4, typename Layout4, + typename Engine5, typename Layout5 +> +CUTLASS_DEVICE void wg1_bunch_0( + Tensor &sScale1, // (BLOCK_SIZE_M) + Tensor &rO1, // ((2, 2, 32), 1, 1) + Tensor &sM, // (BLOCK_SIZE_M) + float rL[2], + int rRightBorderForQSeq[2], + Tensor const &sScale0, // (BLOCK_SIZE_M) + Tensor &rP1, // ((4, 2, 2), 1, 2) + float scale_softmax_log2, + int start_token_idx, + int idx_in_warpgroup +) { + CUTLASS_PRAGMA_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); + + // Mask, and get row-wise max + float cur_max = MAX_INIT_VAL; + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) { + if constexpr (IS_BLK1_LAST || IS_BLK2_LAST) { + // Need to apply the mask when either this block is the last one, or + // the next block is the last one (because of the causal mask) + int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2; + rP1(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP1(i) : MAX_INIT_VAL; + rP1(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? rP1(i+1) : MAX_INIT_VAL; + } else if constexpr (IS_BLK0_LAST) { + rP1(i) = rP1(i+1) = MAX_INIT_VAL; + } + cur_max = max(cur_max, max(rP1(i), rP1(i+1))); + } + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1)); + cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2)); + cur_max *= scale_softmax_log2; + + float old_max = sM(row_idx); + float new_max = max(old_max, cur_max); + float scale_for_old = exp2f(old_max - new_max); + __syncwarp(); + if (idx_in_warpgroup%4 == 0) { + sM(row_idx) = new_max; + sScale1(row_idx) = scale_for_old; + } + + // Scale, exp, and get row-wise expsum + float cur_sum = 0; + if constexpr (!IS_BLK0_LAST) { + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) { + rP1(i) = exp2f(rP1(i)*scale_softmax_log2 - new_max); + rP1(i+1) = exp2f(rP1(i+1)*scale_softmax_log2 - new_max); + cur_sum += rP1(i) + rP1(i+1); + } + } + + // Scale O + float cur_scale_for_o1 = scale_for_old * sScale0(row_idx); + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 
2 : 0; i < size(rO1); i += 4) { + rO1(i) *= cur_scale_for_o1; + rO1(i+1) *= cur_scale_for_o1; + } + + // Update rL + rL[local_row_idx] = rL[local_row_idx]*cur_scale_for_o1 + cur_sum; + } +} + + +// Save rPb (64x64, bfloat16/half/fp8) to sP using the stmatrix instruction +template< + typename T, + typename Engine0, typename Layout0 +> +CUTLASS_DEVICE void save_rPb_to_sP( + Tensor &rPb, + uint16_t* sP_ptr, + int idx_in_warpgroup +) { + using AtomSt = Copy_Atom; + using ThrLayout = Layout>; + using Val_8 = Layout>; + auto copy_r2s = make_tiled_copy(AtomSt{}, ThrLayout{}, Val_8{}); + auto thr_w = copy_r2s.get_slice(idx_in_warpgroup); + + Tensor sPi0 = make_tensor(make_smem_ptr(sP_ptr), typename T::SmemLayoutPi{}); + Tensor sPi0_tile = thr_w.partition_D(sPi0); + Tensor tXrXi0 = make_tensor(make_rmem_ptr(reinterpret_cast(rPb.data())), typename T::RegLayout{}); + copy(copy_r2s, tXrXi0, sPi0_tile); + + Tensor sPi1 = make_tensor(make_smem_ptr(sP_ptr + 1024), typename T::SmemLayoutPi{}); + Tensor sPi1_tile = thr_w.partition_D(sPi1); + Tensor tXrXi1 = make_tensor(make_rmem_ptr(reinterpret_cast(rPb.data()) + 8), typename T::RegLayout{}); + copy(copy_r2s, tXrXi1, sPi1_tile); +} + + +template< + typename T, + typename Engine0, typename Layout0 +> +CUTLASS_DEVICE void load_sP_to_rPb( + uint16_t* sP_ptr, + Tensor &rPb, + int idx_in_warpgroup +) { + using AtomLd = Copy_Atom; + using ThrLayout = Layout>; + using Val_8 = Layout>; + auto copy_s2r = make_tiled_copy(AtomLd{}, ThrLayout{}, Val_8{}); + auto thr_w = copy_s2r.get_slice(idx_in_warpgroup); + + Tensor sPi0 = make_tensor(make_smem_ptr(sP_ptr), typename T::SmemLayoutPi{}); + Tensor sPi0_tile = thr_w.partition_S(sPi0); + Tensor tXrXi0 = make_tensor(make_rmem_ptr(reinterpret_cast(rPb.data()) + 0), typename T::RegLayout{}); + copy(copy_s2r, sPi0_tile, tXrXi0); + + Tensor sPi1 = make_tensor(make_smem_ptr(sP_ptr + 1024), typename T::SmemLayoutPi{}); + Tensor sPi1_tile = thr_w.partition_S(sPi1); + Tensor tXrXi1 = make_tensor(make_rmem_ptr(reinterpret_cast(rPb.data()) + 8), typename T::RegLayout{}); + copy(copy_s2r, sPi1_tile, tXrXi1); +} + + +// Rescale rP0 and save the result to rPb +template< + typename T, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2 +> +CUTLASS_DEVICE void wg0_scale_rP0( + Tensor const &sScale1, // (BLOCK_M) + Tensor const &rP0, // ((4, 2, 2), 1, 2) + Tensor &rPb, // ((4, 2, 2), 1, 2) + int idx_in_warpgroup +) { + CUTLASS_PRAGMA_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); + float scale_factor = sScale1(row_idx); + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) { + rPb(i) = (typename T::InputT)(rP0(i)*scale_factor); + rPb(i+1) = (typename T::InputT)(rP0(i+1)*scale_factor); + } + } +} + + +// Rescale rO0 according to sScale1 +template< + typename Engine0, typename Layout0, + typename Engine1, typename Layout1 +> +CUTLASS_DEVICE void wg0_rescale_rO0( + Tensor &rO0, + Tensor &sScale1, + float rL[2], + int idx_in_warpgroup +) { + CUTLASS_PRAGMA_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); + float scale_factor = sScale1(row_idx); + CUTLASS_PRAGMA_UNROLL + for (int i = local_row_idx ? 
2 : 0; i < size(rO0); i += 4) { + rO0(i) *= scale_factor; + rO0(i+1) *= scale_factor; + } + rL[local_row_idx] *= scale_factor; + } +} + + +// Store O / OAccum +template< + typename T, + bool IS_NO_SPLIT, + typename TMAParams, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1 +> +CUTLASS_DEVICE void store_o( + Tensor &rO, //((_2,_2,_32),_1,_1):((_1,_2,_4),_0,_0) + Tensor &gOorAccum, // (BLOCK_SIZE_M, HEAD_DIM_V) + float rL[2], + char* sO_addr, + TMAParams &tma_params, + int batch_idx, + int k_head_idx, + int m_block_idx, + int num_valid_seq_q, + int warpgroup_idx, + int idx_in_warpgroup +) { + using OutputT = typename T::OutputT; + if constexpr (IS_NO_SPLIT) { + // Should convert the output to bfloat16 / float16, and save it to O + Tensor sOutputBuf = make_tensor(make_smem_ptr((OutputT*)sO_addr), tile_to_shape( + GMMA::Layout_K_SW128_Atom{}, + Shape, Int>{} + )); + + Tensor rOb = make_tensor_like(rO); + CUTLASS_PRAGMA_UNROLL + for (int idx = 0; idx < size(rO); ++idx) { + rOb(idx) = (OutputT)(rO(idx) / rL[idx%4 >= 2]); + } + + Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx)); + TiledCopy r2s_tiled_copy = make_tiled_copy_C( + Copy_Atom{}, + (typename T::TiledMMA_PV_LocalP){} + ); + ThrCopy r2s_thr_copy = r2s_tiled_copy.get_slice(idx_in_warpgroup); + Tensor r2s_thr_copy_rOb = r2s_thr_copy.retile_S(rOb); + Tensor r2s_thr_copy_sMyOutputBuf = r2s_thr_copy.partition_D(sMyOutputBuf); + cute::copy(r2s_tiled_copy, r2s_thr_copy_rOb, r2s_thr_copy_sMyOutputBuf); + cutlass::arch::fence_view_async_shared(); + + __syncthreads(); + + if (threadIdx.x == 0) { + Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, k_head_idx, batch_idx); // (seqlen_q, HEAD_DIM) + auto thr_tma = tma_params.tma_O.get_slice(_0{}); + Tensor my_tma_gO = flat_divide(tma_gO, Shape, Int>{})(_, _, m_block_idx, _0{}); + cute::copy( + tma_params.tma_O, + thr_tma.partition_S(sOutputBuf), + thr_tma.partition_D(my_tma_gO) + ); + cute::tma_store_arrive(); + } + } else { + // Should save the result to OAccum + Tensor sOutputBuf = make_tensor(make_smem_ptr((float*)sO_addr), Layout< + Shape<_64, _512>, + Stride, _1> // We use stride = 520 here to avoid bank conflict + >{}); + + CUTLASS_PRAGMA_UNROLL + for (int idx = 0; idx < size(rO); idx += 2) { + int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 
8 : 0); + int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8; + *(float2*)((float*)sO_addr + sOutputBuf.layout()(row, col)) = float2 { + rO(idx) / rL[idx%4 >= 2], + rO(idx+1) / rL[idx%4 >= 2], + }; + } + cutlass::arch::fence_view_async_shared(); + + __syncthreads(); + + int row = threadIdx.x; + if (row < num_valid_seq_q) { + SM90_BULK_COPY_S2G::copy(&sOutputBuf(row, _0{}), &gOorAccum(row, _0{}), T::HEAD_DIM_V*sizeof(float)); + cute::tma_store_arrive(); + } + } +} + +template< + typename T, + typename TmaParams, typename Tensor0 +> +CUTLASS_DEVICE void launch_q_copy( + TmaParams const &tma_params, + int batch_idx, + int m_block_idx, + int k_head_idx, + Tensor0 &sQ, + TMABarrier* barrier_Q +) { + if (threadIdx.x == 0) { + Tensor tma_gQ = tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, k_head_idx, batch_idx); // (seqlen_q, HEAD_DIM) + auto thr_tma = tma_params.tma_Q.get_slice(_0{}); + Tensor my_tma_gQ = flat_divide(tma_gQ, Shape, Int>{})(_, _, m_block_idx, _0{}); + cute::copy( + tma_params.tma_Q.with(reinterpret_cast(*barrier_Q), 0, cute::TMA::CacheHintSm90::EVICT_FIRST), + thr_tma.partition_S(my_tma_gQ), + thr_tma.partition_D(sQ) + ); + barrier_Q->arrive_and_expect_tx(64*576*sizeof(typename T::InputT)); + } +} + +template +__device__ __forceinline__ +auto get_half_V(Tensorconst& sV) { + return flat_divide(sV, Shape,Int>{})(_,_,Int<(int)IS_R>{},_0{}); +} + + +template< + typename T, + bool IS_BLK0_LAST, // "BLK0" means block_idx+0, "BLK1" means block_idx+1, ... + bool IS_BLK1_LAST, + typename TMAParams, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2, + typename Engine3, typename Layout3, + typename Engine6, typename Layout6, + typename Engine7, typename Layout7, + typename Engine8, typename Layout8, + typename Engine10, typename Layout10, + typename Engine11, typename Layout11, + typename Engine12, typename Layout12, + typename Engine13, typename Layout13 +> +CUTLASS_DEVICE void wg0_subroutine( + Tensor &tma_gK, + Tensor &sQ, + Tensor &sK0, + Tensor &sK1, + uint16_t* sP0_ptr, + uint16_t* sP1_ptr, + Tensor &sM, + Tensor &sScale0, + Tensor &sScale1, + Tensor &rP0, + Tensor &rO0, + Tensor &sV0, + Tensor &sV1, + typename T::InputT* sVt0_ptr, + float rL[2], + int rRightBorderForQSeq[2], + TMABarrier barriers_K0[9], + TMABarrier barriers_K1[9], + bool &cur_phase_K0, + const TMAParams &tma_params, + const Flash_fwd_mla_params ¶ms, + int* block_table_ptr, + int seqlen_k, + int block_idx, + int end_block_idx, + int idx_in_warpgroup +) { + int start_token_idx = block_idx * T::PAGE_BLOCK_SIZE; + #define GET_BLOCK_INDEX(block_idx) ((block_idx) >= end_block_idx ? 
0 : __ldg(block_table_ptr + (block_idx))) + int nxt_block0_index = GET_BLOCK_INDEX(block_idx+2); + int nxt_block1_index = GET_BLOCK_INDEX(block_idx+3); + + Tensor sV0L = get_half_V(sV0); + Tensor sV1L = get_half_V(sV1); + + //Tensor rPb = make_tensor(Shape, _1, _2>{}); //k=64/32 + // Calc P0 = softmax(P0) + + wg0_bunch_0(rP0, rO0, sScale0, sM, rL, rRightBorderForQSeq, params.scale_softmax_log2, start_token_idx, idx_in_warpgroup); + NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sScale0Ready); + + permute_Cregs_128_to_64(rP0); + cute::warpgroup_fence_operand(rP0); + + Tensor tOrP_acc = make_tensor(rP0.data(), Layout< + Shape, _1, _2>, + Stride, _0, _16> + >{}); + Tensor rPb = make_tensor_like(tOrP_acc); + + convert_type_out(tOrP_acc, rPb); + + // Issue rO0 += rPb @ sV0L + warpgroup_cooperative_pv_gemm_localP(rPb, sV0L, rO0, idx_in_warpgroup); + + // Wait for rO0, launch TMA for the next V0L + cute::warpgroup_wait<0>(); + + // Wait for warpgroup 1, rescale P0, notify warpgroup 1 + NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sScale1Ready); + if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { + // Put it here seems to be faster, don't know why + launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, nxt_block0_index), sK0, tma_params.tma_K, barriers_K0, idx_in_warpgroup); + } + wg0_scale_rP0(sScale1, rP0, rPb, idx_in_warpgroup); + save_rPb_to_sP(rPb, sP0_ptr, idx_in_warpgroup); + cutlass::arch::fence_view_async_shared(); + NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sP0Ready); + + // Wait for warpgroup 1, rescale O0, issue rO0 += rPb @ sV1L + if constexpr (!IS_BLK0_LAST) { + NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::rO1sP0sV0RIssued); + wg0_rescale_rO0(rO0, sScale1, rL, idx_in_warpgroup); + //warpgroup_cooperative_pv_gemm_remoteP(sP1, sV1L, rO0, idx_in_warpgroup); // replace + load_sP_to_rPb(sP1_ptr, rPb, idx_in_warpgroup); + warpgroup_cooperative_pv_gemm_localP(rPb, sV1L, rO0, idx_in_warpgroup); + } + + // Issue P0 = Q @ K0^T + // Since TMAs for these 4 tiles are launched right after rO0 += rPb @ sV0L finishes, they should have already finished. Therefore, we issue the first 4 tiles to fill the pipeline. 
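    // Illustrative sketch (for exposition only, not used by the kernel): per row, the
    // wg0_bunch_0 / wg1_bunch_0 calls above implement the usual online-softmax
    // recurrence, kept in the base-2 domain (hence exp2f and scale_softmax_log2);
    // warpgroup 1 additionally folds in warpgroup 0's factor via sScale0 / sScale1.
    // A scalar model of one update step, with hypothetical names:
    auto online_softmax_update_model = [](float& m, float& l, const float* p, int n,
                                          float scale_log2) -> float {
        float cur = MAX_INIT_VAL;
        for (int i = 0; i < n; ++i) cur = fmaxf(cur, p[i]);           // row max of the new scores
        float m_new = fmaxf(m, cur * scale_log2);                     // new running max
        float alpha = exp2f(m - m_new);                               // rescale factor for old O and l
        float s = 0.0f;
        for (int i = 0; i < n; ++i) s += exp2f(p[i] * scale_log2 - m_new);
        l = l * alpha + s;
        m = m_new;
        return alpha;
    };
    (void)online_softmax_update_model;  // model only, intentionally unused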
+ if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { + warpgroup_cooperative_qkt_gemm(sQ, sK0, rP0, sVt0_ptr, barriers_K0, cur_phase_K0, idx_in_warpgroup, NamedBarriers::sV0ZeroReady, seqlen_k - (block_idx + 2) * T::PAGE_BLOCK_SIZE); + } + + // Wait for rO0 += rPb @ sV1L, launch TMA + if (!IS_BLK0_LAST && !IS_BLK1_LAST && __builtin_expect(block_idx+3 < end_block_idx, true)) { + cute::warpgroup_wait<4>(); + launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, nxt_block1_index), sK1, tma_params.tma_K, barriers_K1, idx_in_warpgroup); + } + + // Issue P0 = Q @ K0^T + if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { + warpgroup_cooperative_qkt_gemm(sQ, sK0, rP0, sVt0_ptr, barriers_K0, cur_phase_K0, idx_in_warpgroup, NamedBarriers::sV0ZeroReady, seqlen_k - (block_idx + 2) * T::PAGE_BLOCK_SIZE); + } + + // Wait for P0 = Q @ K0^T + cute::warpgroup_wait<0>(); +} + + +template< + typename T, + bool IS_BLK0_LAST, + bool IS_BLK1_LAST, + bool IS_BLK2_LAST, + typename TMAParams, + typename Engine0, typename Layout0, + typename Engine1, typename Layout1, + typename Engine2, typename Layout2, + typename Engine3, typename Layout3, + typename Engine6, typename Layout6, + typename Engine7, typename Layout7, + typename Engine8, typename Layout8, + typename Engine10, typename Layout10, + typename Engine11, typename Layout11, + typename Engine12, typename Layout12, + typename Engine13, typename Layout13 +> +CUTLASS_DEVICE void wg1_subroutine( + Tensor &tma_gK, + Tensor &sQ, + Tensor &sK0, + Tensor &sK1, + uint16_t* sP0_ptr, + uint16_t* sP1_ptr, + Tensor &sM, + Tensor &sScale0, + Tensor &sScale1, + Tensor &rP1, + Tensor &rO1, + Tensor &sV0, + Tensor &sV1, + typename T::InputT* sVt1_ptr, + float rL[2], + int rRightBorderForQSeq[2], + TMABarrier barriers_K0[9], + TMABarrier barriers_K1[9], + bool &cur_phase_K1, + const TMAParams &tma_params, + const Flash_fwd_mla_params ¶ms, + int* block_table_ptr, + int seqlen_k, + int block_idx, + int end_block_idx, + int idx_in_warpgroup +) { + int start_token_idx = block_idx * T::PAGE_BLOCK_SIZE; + int nxt_block0_index = GET_BLOCK_INDEX(block_idx+2); + int nxt_block1_index = GET_BLOCK_INDEX(block_idx+3); + + Tensor sV0R = get_half_V(sV0); + Tensor sV1R = get_half_V(sV1); + + // Wait for rP1 and warpgroup 0, run bunch 1, notify warpgroup 0 + NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sScale0Ready); + wg1_bunch_0(sScale1, rO1, sM, rL, rRightBorderForQSeq, sScale0, rP1, params.scale_softmax_log2, start_token_idx+T::PAGE_BLOCK_SIZE, idx_in_warpgroup); + NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sScale1Ready); + permute_Cregs_128_to_64(rP1); + cute::warpgroup_fence_operand(rP1); + + Tensor tOrP_acc = make_tensor(rP1.data(), Layout< + Shape, _1, _2>, + Stride, _0, _16> + >{}); + Tensor rP1b = make_tensor_like(tOrP_acc); + + convert_type_out(tOrP_acc, rP1b); + cute::warpgroup_fence_operand(rP1b); + + // Save rPb to sP, and issue rO1 += rP1b @ sV1R + // We do this after notifying warpgroup 1, since both "saving rPb to sP" and "issuing" WGMMA are high-latency operations + if constexpr (!IS_BLK0_LAST) { + save_rPb_to_sP(rP1b, sP1_ptr, idx_in_warpgroup); + } + if constexpr (!IS_BLK0_LAST) { + warpgroup_cooperative_pv_gemm_localP(rP1b, sV1R, rO1, idx_in_warpgroup); + if constexpr (!IS_BLK1_LAST) { + // We use this proxy for making sP1 visible to the async proxy + // We skip it if IS_BLK1_LAST, since in that case we have already put a fence + cutlass::arch::fence_view_async_shared(); + } + } + + // Wait for sP0, issue rO1 += sP0 @ sV0R, notify warpgroup 0 + 
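    // Note (for exposition only): the per-iteration handshake between the two warpgroups,
    // using the named barriers declared in traits.h, is:
    //   wg0: wg0_bunch_0                                -> arrive(sScale0Ready)
    //   wg1: wait(sScale0Ready), wg1_bunch_0            -> arrive(sScale1Ready)
    //   wg0: wait(sScale1Ready), rescale P0, store sP0  -> arrive(sP0Ready)
    //   wg1: wait(sP0Ready), issue rO1 += P0 @ V0R      -> arrive(rO1sP0sV0RIssued)
    //   wg0: wait(rO1sP0sV0RIssued), rescale rO0, issue rO0 += P1 @ V1L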
NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sP0Ready); + + //warpgroup_cooperative_pv_gemm_remoteP(sP0, sV0R, rO1, idx_in_warpgroup); // replace + load_sP_to_rPb(sP0_ptr, rP1b, idx_in_warpgroup); + warpgroup_cooperative_pv_gemm_localP(rP1b, sV0R, rO1, idx_in_warpgroup); + if constexpr (!IS_BLK0_LAST) { + NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::rO1sP0sV0RIssued); + } + + // Wait for rO1 += rP1b @ sV1R, launch TMA for the next V1R + if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST && !IS_BLK2_LAST) { + cute::warpgroup_wait<1>(); + launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, nxt_block1_index), sK1, tma_params.tma_K, barriers_K1, idx_in_warpgroup); + } + + // Wait for rO1 += sP0 @ sV0R, launch TMA for the next V0R + if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) { + cute::warpgroup_wait<0>(); + launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, nxt_block0_index), sK0, tma_params.tma_K, barriers_K0, idx_in_warpgroup); + } + + if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST && !IS_BLK2_LAST) { + // Issue rP1 = sQ @ sK1, wait + warpgroup_cooperative_qkt_gemm(sQ, sK1, rP1, sVt1_ptr, barriers_K1, cur_phase_K1, idx_in_warpgroup, NamedBarriers::sV1ZeroReady, seqlen_k - (block_idx + 3) * T::PAGE_BLOCK_SIZE); + } + + // We put the `cute::warpgroup_wait<0>()` out of the `if` statement above, otherwise + // nvcc cannot correctly analyse the loop, and will think that we are using accumulator + // registers during the WGMMA pipeline, which results in `WARPGROUP.ARRIVE` and `WARPGROUP.DEPBAR.LE` being inserted in SASS and WGMMA instructions being serialized. + // This is also the reason why we put QK^T here, instead of the first operation in the loop + cute::warpgroup_wait<0>(); +} + +// A helper function for determining the length of the causal mask for one q token +CUTLASS_DEVICE int get_mask_len(const Flash_fwd_mla_params ¶ms, int m_block_idx, int local_seq_q_idx) { + int global_seq_q_idx = m_block_idx*Config::BLOCK_SIZE_M + local_seq_q_idx; + if (global_seq_q_idx < params.q_seq_per_hk) { + int s_q_idx = global_seq_q_idx / params.q_head_per_hk; + return params.s_q - s_q_idx - 1; + } else { + // Out-of-bound request, regard as no masks + return 0; + } +} + +template +__global__ void __launch_bounds__(T::NUM_THREADS, 1, 1) +flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params, __grid_constant__ const TmaParams tma_params) { + // grid shape: [ + // num_m_blocks (=ceil_div(seqlen_q_ori*(num_q_heads//num_kv_heads))), + // num_kv_heads, + // num_sm_parts + // ] + // An "sm part" is responsible for all the BLOCK_SIZE_M q_heads in the m_block (as specified by m_block_idx), under one kv head (as specified by k_head_idx), of a segment (as specified by [start_block_idx, end_block_idx]) of one request (as specified by batch_idx). + // If is_no_split is True, then this request is exclusively assigned to this sm_part, so we shall write the result directly into params.o_ptr and params.softmax_lse_ptr. Otherwise, write to oaccum_ptr and softmax_lseaccum_ptr, with the corresponding split idx being (n_split_idx + num_splits_ptr[batch_idx]) + // For the complete schedule of the kernel, please read our deep-dive write-up (link can be found in the README.md file). 
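    // Note (for exposition only): the host launcher at the bottom of this file builds
    // exactly this grid:
    //   const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M);
    //   dim3 grid(num_m_block, params.h_k, params.num_sm_parts);
    //   dim3 block(T::NUM_THREADS, 1, 1);   // 256 threads = 2 warpgroups
    // so blockIdx.x picks the m-block of query heads, blockIdx.y the KV head, and
    // blockIdx.z the SM partition assigned by the tile scheduler.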
+ + const int m_block_idx = blockIdx.x; + const int k_head_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int warpgroup_idx = threadIdx.x / 128; + const int idx_in_warpgroup = threadIdx.x % 128; + + // Define shared tensors + extern __shared__ char wksp_buf[]; + using SharedMemoryPlan = typename T::SharedMemoryPlan; + SharedMemoryPlan &plan = *reinterpret_cast(wksp_buf); + + typename T::InputT* sVt0_ptr = reinterpret_cast(plan.smem_vt0); + typename T::InputT* sVt1_ptr = reinterpret_cast(plan.smem_vt1); + + Tensor sQ = make_tensor(make_smem_ptr(plan.smem_sQ), (typename T::SmemLayoutQ){}); + Tensor sK0 = make_tensor(make_smem_ptr(reinterpret_cast(plan.smem_sK0)), (typename T::SmemLayoutK){}); + Tensor sK1 = make_tensor(make_smem_ptr(reinterpret_cast(plan.smem_sK1)), (typename T::SmemLayoutK){}); + //Tensor sP0 = make_tensor(make_smem_ptr(reinterpret_cast(plan.smem_sP0.data())), (typename T::SmemLayoutP0){}); + //Tensor sP1 = make_tensor(make_smem_ptr(reinterpret_cast(plan.smem_sP1.data())), (typename T::SmemLayoutP0){}); + Tensor sM = make_tensor(make_smem_ptr(plan.smem_sM.data()), make_shape(Int{})); + + using Fp8Trans = SmemTransposeFp8_64x64; + auto sV0 = [&]{ + return make_tensor(make_smem_ptr(sVt0_ptr), + typename Fp8Trans::SmemLayoutVt{}); + }(); + auto sV1 = [&]{ + return make_tensor(make_smem_ptr(sVt1_ptr), + typename Fp8Trans::SmemLayoutVt{}); + }(); + + Tensor sL_reduction_wksp = make_tensor(make_smem_ptr(plan.sL_reduction_wksp.data()), make_shape(Int<2*T::BLOCK_SIZE_M>{})); + Tensor sScale0 = make_tensor(make_smem_ptr(plan.smem_sScale0.data()), make_shape(Int{})); + Tensor sScale1 = make_tensor(make_smem_ptr(plan.smem_sScale1.data()), make_shape(Int{})); + char* sO_addr = (char*)plan.smem_sK0; // Overlap with sK0 sK1 sV0 sV1 + // Prefetch TMA descriptors + if (threadIdx.x == 0) { + cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_params.tma_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor()); + } + + // Define TMA stuffs + Tensor tma_gK = tma_params.tma_K.get_tma_tensor(tma_params.shape_K)(_, _, k_head_idx, _); + TMABarrier* barriers_K0 = plan.barriers_K0; + TMABarrier* barriers_K1 = plan.barriers_K1; + TMABarrier* barrier_Q = &(plan.barrier_Q); + + // Initialize TMA barriers + if (threadIdx.x == 0) { + barrier_Q->init(1); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 9; ++i) { + barriers_K0[i].init(1); + barriers_K1[i].init(1); + } + cutlass::arch::fence_view_async_shared(); + } + __syncthreads(); + bool cur_phase_Q = 0, cur_phase_K0 = 0, cur_phase_K1 = 0; + + // Programmatic Dependent Launch: Wait for the previous kernel to finish + cudaGridDependencySynchronize(); + + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; + // We don't use __ldg here, otherwise NVCC (ptxas, in particular) will do instruction reorder and place __ldg (LDG.E.128.CONSTANT in SASS) in front of cudaGridDependencySynchronize() (ACQBULK in SASS), leading to data race. 
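    // Illustrative sketch (for exposition only): the Programmatic Dependent Launch pattern
    // this kernel relies on, shown in isolation (kernel and variable names are hypothetical):
    //
    //   __global__ void consumer(const int* meta, int* out) {
    //       cudaGridDependencySynchronize();            // wait until the producer's writes are visible
    //       out[blockIdx.x] = meta[blockIdx.x];          // only read producer data after the sync
    //       cudaTriggerProgrammaticLaunchCompletion();   // allow the next kernel to start early
    //   }
    //
    //   // Host side, opting in per launch exactly as run_flash_splitkv_mla_kernel does below:
    //   cudaLaunchAttribute attr[1];
    //   attr[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    //   attr[0].val.programmaticStreamSerializationAllowed = 1;
    //   cudaLaunchConfig_t cfg{grid, block, /*dynamic smem*/ 0, stream, attr, 1};
    //   cudaLaunchKernelEx(&cfg, consumer, meta_ptr, out_ptr);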
+ int4 tile_scheduler_metadata = *(reinterpret_cast(tile_scheduler_metadata_ptr)); + int begin_idx = tile_scheduler_metadata.x; + int begin_seqlen = tile_scheduler_metadata.y; + int end_idx = tile_scheduler_metadata.z; + int end_seqlen = tile_scheduler_metadata.w; + + if (begin_idx >= params.b) return; + int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4); + + // Copy the first Q + launch_q_copy(tma_params, begin_idx, m_block_idx, k_head_idx, sQ, barrier_Q); + + #pragma unroll 1 + for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) { + constexpr int kBlockN = T::PAGE_BLOCK_SIZE; + const int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0; + int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx); + const int start_block_idx = batch_idx == begin_idx ? begin_seqlen / kBlockN : 0; + int end_block_idx = batch_idx == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); + const bool is_no_split = start_block_idx == 0 && end_block_idx == cute::ceil_div(seqlen_k, kBlockN); + + int rRightBorderForQSeq[2]; + if (params.is_causal) { + // The causal mask looks like: + // XXXX + // XXXX + // ... + // XXXX + // XXX + // XXX + // ... + // XXX + // XX + // XX + // ... + // XX + // Firstly, there is a common_mask_len, which is the minimum length of causal masks among all tokens. Since the length of the causal mask decreases monotonically, the common_mask_len is the length of the causal mask for the last token. We consider the common_mask_len as a "reduction in the length of the k-sequence.", and adjust end_block_idx based on it, to save some calculation. + // Besides, a token may have some extra masks other than the common mask. We use rRightBorderForQSeq to denote it, which means the right border of the k-sequence for the particular q token. In this way, (seqlen_k-common_mask_len) - rRightBorderForQSeq < 64 holds, which means that we only need to apply the causal mask to the last two KV blocks + // NOTE This may lead to start_block_idx >= end_block_idx which needs some special handling + int common_mask_len = get_mask_len(params, m_block_idx, T::BLOCK_SIZE_M-1); + end_block_idx = batch_idx == end_idx ? 
cute::ceil_div(min(end_seqlen, seqlen_k-common_mask_len), kBlockN) : cute::ceil_div(seqlen_k-common_mask_len, kBlockN); + + CUTLASS_PRAGMA_UNROLL + for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) { + int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup); + rRightBorderForQSeq[local_row_idx] = min(seqlen_k-get_mask_len(params, m_block_idx, row_idx), end_block_idx*T::PAGE_BLOCK_SIZE); + } + } else { + rRightBorderForQSeq[0] = rRightBorderForQSeq[1] = seqlen_k; + } + + // Define global tensors + typename T::OutputT* o_ptr = (typename T::OutputT*)params.o_ptr + batch_idx*params.o_batch_stride + m_block_idx*T::BLOCK_SIZE_M*params.o_row_stride + k_head_idx*params.o_head_stride; // (BLOCK_SIZE_M, HEAD_DIM_V) : (params.o_row_stride, 1) + float* softmax_lse_ptr = (float*)params.softmax_lse_ptr + (batch_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1) + int* block_table_ptr = params.block_table + batch_idx*params.block_table_batch_stride; // (/) : (1) + + Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout( + Shape, Int>{}, + make_stride(params.o_row_stride, _1{}) + )); + Tensor gSoftmaxLse = make_tensor(make_gmem_ptr(softmax_lse_ptr), Layout< + Shape>, + Stride<_1> + >{}); + + // Copy K0 and K1 + launch_kv_tiles_copy_tma<0, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx)), sK0, tma_params.tma_K, barriers_K0, threadIdx.x); + if (start_block_idx+1 < end_block_idx) { + launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x); + launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x); + } + + Tensor rO = partition_fragment_C((typename T::TiledMMA_PV_LocalP){}, Shape, Int>{}); // ((2, 2, 32), 1, 1) + float rL[2]; + rL[0] = rL[1] = 0.0f; + + // Clear buffers + cute::clear(rO); + if (threadIdx.x < size(sM)) { + sM[threadIdx.x] = MAX_INIT_VAL_SM; + } + + // Wait for Q + barrier_Q->wait(cur_phase_Q); + cur_phase_Q ^= 1; + + if (warpgroup_idx == 0) { + // Warpgroup 0 + Tensor rP0 = make_tensor((typename T::rP0Layout){}); + + // NOTE We don't use the pipelined version of Q K^T here since it leads + // to a slow-down (or even register spilling, thanks to the great NVCC) + // Wait for K0 + auto sK0_ptr = sK0.data().get().get(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 9; ++i) { + if (idx_in_warpgroup == 0) + barriers_K0[i].arrive_and_expect_tx(64*64*sizeof(typename T::InputT)); + barriers_K0[i].wait(cur_phase_K0); + if(i!=8) { + int valid_window_size = seqlen_k - start_block_idx * T::PAGE_BLOCK_SIZE; + if(valid_window_size>0 && valid_window_size(sK0_ptr + i*64*64, valid_window_size, idx_in_warpgroup); + cutlass::arch::fence_view_async_shared(); + NamedBarrier::arrive_and_wait(128, NamedBarriers::sPreV0ZeroReady); + } + fp8_transpose_v(sK0_ptr, sVt0_ptr, i); // transpose sK + cutlass::arch::fence_view_async_shared(); + } + } + cur_phase_K0 ^= 1; + // Issue P0 = Q @ K0^T, wait + warpgroup_cooperative_qkt_gemm_no_pipeline(sQ, sK0, rP0, idx_in_warpgroup); + // We add a barrier here, making sure that previous writes to sM are visible to warpgroup 0 + NamedBarrier::arrive_and_wait(128, NamedBarriers::sMInitialized); + cute::warpgroup_wait<0>(); + + #define LAUNCH_WG0_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST) \ + wg0_subroutine( \ + tma_gK, sQ, sK0, sK1, plan.smem_sP0, plan.smem_sP1, sM, sScale0, sScale1, \ + rP0, rO, sV0, sV1, sVt0_ptr, rL, rRightBorderForQSeq, 
\ + barriers_K0, barriers_K1, cur_phase_K0, \ + tma_params, params, \ + block_table_ptr, seqlen_k, block_idx, end_block_idx, idx_in_warpgroup \ + ); + + int block_idx = start_block_idx; + #pragma unroll 1 + for (; block_idx < end_block_idx-2; block_idx += 2) { + LAUNCH_WG0_SUBROUTINE(false, false); + } + + if (block_idx+1 < end_block_idx) { + LAUNCH_WG0_SUBROUTINE(false, true); + } else if (block_idx < end_block_idx) { + LAUNCH_WG0_SUBROUTINE(true, false); + } + + } else { + // Warpgroup 1 + Tensor rP1 = make_tensor((typename T::rP0Layout){}); + + if (start_block_idx+1 < end_block_idx) { + // Issue rP1 = sQ @ sK1, wait + warpgroup_cooperative_qkt_gemm(sQ, sK1, rP1, sVt1_ptr, barriers_K1, cur_phase_K1, idx_in_warpgroup, NamedBarriers::sPreV1ZeroReady, seqlen_k - (start_block_idx+1) * T::PAGE_BLOCK_SIZE); + cute::warpgroup_wait<0>(); + } + + #define LAUNCH_WG1_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST, IS_BLK2_LAST) \ + wg1_subroutine( \ + tma_gK, sQ, sK0, sK1, plan.smem_sP0, plan.smem_sP1, sM, sScale0, sScale1, \ + rP1, rO, sV0, sV1, sVt1_ptr, rL, rRightBorderForQSeq, \ + barriers_K0, barriers_K1, cur_phase_K1, \ + tma_params, params, \ + block_table_ptr, seqlen_k, block_idx, end_block_idx, idx_in_warpgroup \ + ); + + int block_idx = start_block_idx; + #pragma unroll 1 + for (; block_idx < end_block_idx-3; block_idx += 2) { + LAUNCH_WG1_SUBROUTINE(false, false, false); + } + + if (block_idx+2 < end_block_idx) { + LAUNCH_WG1_SUBROUTINE(false, false, true); + block_idx += 2; + LAUNCH_WG1_SUBROUTINE(true, false, false); + } else if (block_idx+1 < end_block_idx) { + LAUNCH_WG1_SUBROUTINE(false, true, false); + } else if (block_idx < end_block_idx) { + LAUNCH_WG1_SUBROUTINE(true, false, false); + } + } + + // Reduce rL across threads within the same warp + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1); + rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1); + rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2); + + // Reduce rL across warpgroups + int my_row = get_AorC_row_idx(0, idx_in_warpgroup); + if (idx_in_warpgroup%4 == 0) { + sL_reduction_wksp[my_row + warpgroup_idx*64] = rL[0]; + sL_reduction_wksp[my_row + 8 + warpgroup_idx*64] = rL[1]; + } + __syncthreads(); + if (warpgroup_idx == 0) { + rL[0] += sL_reduction_wksp[my_row + 64]; + rL[1] += sL_reduction_wksp[my_row + 8 + 64]; + } else { + if (idx_in_warpgroup%4 == 0) { + sL_reduction_wksp[my_row] += rL[0]; + sL_reduction_wksp[my_row + 8] += rL[1]; + } + __syncwarp(); + rL[0] = sL_reduction_wksp[my_row]; + rL[1] = sL_reduction_wksp[my_row+8]; + } + + // Prune out when rL is 0.0f or NaN + // rL may be 0.0f if there are large values (~10^12) in QK^T, which leads + // to exp2f(P(i)*scale-max) = 0.0f or +inf due to FMA error. + // When this happens, we set rL to 1.0f. This aligns with the old version + // of the MLA kernel. + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) + rL[i] = (rL[i] == 0.0f || rL[i] != rL[i]) ? 
1.0f : rL[i]; + + // Copy Q for the next batch + if (batch_idx+1 <= end_idx) { + launch_q_copy(tma_params, batch_idx+1, m_block_idx, k_head_idx, sQ, barrier_Q); + } else { + // Allow the next kernel (the combine kernel) to launch + // The next kernel MUST be the combine kernel + cudaTriggerProgrammaticLaunchCompletion(); + } + + int num_valid_seq_q = min(params.q_seq_per_hk - m_block_idx*T::BLOCK_SIZE_M, T::BLOCK_SIZE_M); + if (is_no_split) { + store_o(rO, gO, rL, sO_addr, tma_params, batch_idx, k_head_idx, m_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + int i = threadIdx.x; + if (i < num_valid_seq_q) { + float cur_L = sL_reduction_wksp[i]; + gSoftmaxLse(i) = (cur_L == 0.0f || cur_L != cur_L) ? INFINITY : logf(cur_L) + sM(i) / (float)M_LOG2E; + } + + cute::tma_store_wait<0>(); + } else { + // Don't use __ldg because of PDL and instruction reordering + int split_idx = params.num_splits_ptr[batch_idx] + n_split_idx; + float* oaccum_ptr = (float*)params.oaccum_ptr + ((split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) + float* softmax_lseaccum_ptr = (float*)params.softmax_lseaccum_ptr + (split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1) + Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout< + Shape, Int>, + Stride, _1> + >{}); + Tensor gSoftmaxLseAccum = make_tensor(make_gmem_ptr(softmax_lseaccum_ptr), Layout< + Shape>, + Stride<_1> + >{}); + store_o(rO, gOAccum, rL, sO_addr, tma_params, batch_idx, k_head_idx, m_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup); + + int i = threadIdx.x; + if (i < num_valid_seq_q) { + float cur_L = sL_reduction_wksp[i]; + gSoftmaxLseAccum(i) = (cur_L == 0.0f || cur_L != cur_L) ? 
-INFINITY : log2f(cur_L) + sM(i); + } + + cute::tma_store_wait<0>(); + } + + if (batch_idx != end_idx) + __syncthreads(); + } +} + + +template +void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { + using TYPE = TraitsFP8; + auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b); + using AtomQ = decltype(get_smem_layoutK()); + auto tma_Q = cute::make_tma_copy( + SM90_TMA_LOAD{}, + make_tensor( + make_gmem_ptr((InputT*)params.q_ptr), + make_layout( + shape_Q, + make_stride(params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride) + ) + ), + tile_to_shape( + AtomQ{}, + Shape, Int>{} + ) + ); + + auto shape_K = make_shape(Int{}, Int{}, params.h_k, params.num_blocks); + using AtomK = decltype(get_smem_layoutK()); + auto tma_K = cute::make_tma_copy( + SM90_TMA_LOAD{}, + make_tensor( + make_gmem_ptr((InputT*)params.k_ptr), + make_layout( + shape_K, + make_stride(params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride) + ) + ), + tile_to_shape( + AtomK{}, + Layout< + Shape, Int<64>>, + Stride, _1> + >{} + ) + ); + + auto shape_O = make_shape(params.q_seq_per_hk, params.d_v, params.h_k, params.b); + using AtomO = decltype(get_smem_layoutK()); + auto tma_O = cute::make_tma_copy( + SM90_TMA_STORE{}, + make_tensor( + make_gmem_ptr((typename TYPE::OutputT*)params.o_ptr), + make_layout( + shape_O, + make_stride(params.o_row_stride, _1{}, params.o_head_stride, params.o_batch_stride) + ) + ), + tile_to_shape( + AtomO{}, + Shape, Int>{} + ) + ); + TmaParams tma_params = { + shape_Q, tma_Q, + shape_K, tma_K, + shape_O, tma_O + }; + auto mla_kernel = &flash_fwd_splitkv_mla_kernel; + constexpr size_t smem_size = sizeof(typename TYPE::SharedMemoryPlan); + CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) + const int num_m_block = cute::ceil_div(params.q_seq_per_hk, TYPE::BLOCK_SIZE_M); + cudaLaunchAttribute mla_kernel_attributes[1]; + mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + mla_kernel_attributes[0].val.programmaticStreamSerializationAllowed = 1; + cudaLaunchConfig_t mla_kernel_config = { + dim3(num_m_block, params.h_k, params.num_sm_parts), + dim3(TYPE::NUM_THREADS, 1, 1), + smem_size, + stream, + mla_kernel_attributes, + 1 + }; + cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); +#ifndef FLASH_MLA_DISABLE_FP16 +template void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params ¶ms, cudaStream_t stream); +#endif + diff --git a/csrc/sm90/kernels/traits.h b/csrc/sm90/kernels/traits.h index 5f915a6..574683a 100644 --- a/csrc/sm90/kernels/traits.h +++ b/csrc/sm90/kernels/traits.h @@ -10,6 +10,21 @@ using TMABarrier = cutlass::arch::ClusterTransactionBarrier; using namespace cute; +using TMABarrier = cutlass::arch::ClusterTransactionBarrier; +template +constexpr auto get_smem_layoutK() { + constexpr int b0 = sizeof(T)*D, b1=sizeof(T)*D2; + if constexpr(M==GMMA::Major::K){ + if constexpr(b0%128==0 && b1%128==0) return GMMA::Layout_K_SW128_Atom{}; + else if constexpr(b0%64==0 && b1%64==0) return GMMA::Layout_K_SW64_Atom{}; + else return GMMA::Layout_K_SW32_Atom{}; + } else { + if constexpr(b0%128==0 && b1%128==0) return GMMA::Layout_MN_SW128_Atom{}; + else if constexpr(b0%64==0 && b1%64==0) return 
GMMA::Layout_MN_SW64_Atom{}; + else return GMMA::Layout_MN_SW32_Atom{}; + } +} + template struct Traits { using InputT = InputT_; @@ -84,6 +99,90 @@ struct Traits { }; +template +struct TraitsFP8 { + using InputT = T_; + using OutputT = OutputT_; + static constexpr bool IsFp8 = std::is_same_v; + + static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M; // 64 + static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE; // 64 + static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K; // 576 + static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V; // 512 + static constexpr int NUM_THREADS = 256; + + static_assert( std::is_same_v ); + + static_assert( std::is_same_v || + std::is_same_v ); + + using TiledMMA_QK_sQ = decltype(make_tiled_mma( + GMMA::ss_op_selector,Int,Int>, + GMMA::Major::K, GMMA::Major::K>(), + Layout>{})); + + using TiledMMA_PV_LocalP = decltype(make_tiled_mma( + GMMA::rs_op_selector,Int,Int>, + GMMA::Major::K, GMMA::Major::K>(), + Layout>{})); + + + using TiledMMA_PV_RemoteP = decltype(make_tiled_mma( + GMMA::ss_op_selector,Int,Int>, + GMMA::Major::K, GMMA::Major::K>(), + Layout>{})); + + using SmemLayoutQ = decltype(tile_to_shape( + get_smem_layoutK(), + Shape,Int>{})); + + using SmemLayoutK = decltype(tile_to_shape( + get_smem_layoutK(), + Shape,Int>{})); + + using SmemLayoutV = SmemLayoutK; + using SmemLayoutV_Trans = decltype( + composition(SmemLayoutK{}, + make_layout(Shape,Int>{},GenRowMajor{}))); + + using SmemLayoutP0 = decltype(tile_to_shape( + get_smem_layoutK(), + Shape, Int>{} + )); + + using rP0Layout = decltype(layout(partition_fragment_C( + TiledMMA_QK_sQ{}, + Shape, Int>{} + ))); + + using SmemLayoutPi = Layout>, + Stride< _2,Stride<_1,_256>>>; + + using RegLayout = Layout,_1>,_1,_1>, + Stride,_0>,_0,_0>>; + + struct SharedMemoryPlan { + alignas(16) InputT smem_sQ[BLOCK_SIZE_M * HEAD_DIM_K]; + alignas(16) InputT smem_sK0[PAGE_BLOCK_SIZE * HEAD_DIM_K]; // overlap Sout + alignas(16) InputT smem_sK1[PAGE_BLOCK_SIZE * HEAD_DIM_K]; // overlap Sout + alignas(16) InputT smem_vt0[PAGE_BLOCK_SIZE * HEAD_DIM_V]; // overlap Sout + alignas(16) InputT smem_vt1[PAGE_BLOCK_SIZE * HEAD_DIM_V]; // overlap Sout + alignas(16) uint16_t smem_sP0[64 * 32]; + alignas(16) uint16_t smem_sP1[64 * 32]; + + cute::array_aligned smem_sM; + cute::array_aligned sL_reduction_wksp; + cute::array_aligned smem_sScale0; + cute::array_aligned smem_sScale1; + TMABarrier barriers_K0[HEAD_DIM_K/64]; + TMABarrier barriers_K1[HEAD_DIM_K/64]; + TMABarrier barrier_Q; + }; +}; + template< typename ShapeQ, typename TMA_Q, typename ShapeK, typename TMA_K, @@ -104,4 +203,8 @@ enum NamedBarriers : int { sP0Ready = 2, rO1sP0sV0RIssued = 3, sMInitialized = 4, + sPreV0ZeroReady = 5, + sPreV1ZeroReady = 6, + sV0ZeroReady = 7, + sV1ZeroReady = 8 }; diff --git a/csrc/sm90/kernels/utils.h b/csrc/sm90/kernels/utils.h index ae9d0fc..f40e65d 100644 --- a/csrc/sm90/kernels/utils.h +++ b/csrc/sm90/kernels/utils.h @@ -1,5 +1,96 @@ #pragma once +#include +#include +#include +#include +#include +#include +#include "fp8_transpose_v.h" + +using namespace cute; + +// Fill out-of-bound V with 0.0 +// We must fill it since it may contain NaN, which may propagate to the final result +template< + typename T +> +CUTLASS_DEVICE void fill_oob_KV( + typename T::InputT* sV_ptr, // ptr of tensor(tile_to_shape(Shape,Int>{})) + int valid_window_size, + int idx_in_warpgroup +) { + Tensor sV_int64 = make_tensor( + make_smem_ptr((int64_t*)(sV_ptr)), + tile_to_shape( + GMMA::Layout_K_SW64_Atom{}, + Shape, Int<8>>{} // (64, 64/(64/8)) + ) + 
+    );
+    valid_window_size = max(valid_window_size, 0);
+    int head_dim_size = 8; // 128%head_dim_size == 0 should hold
+    for (int token_idx = valid_window_size + (idx_in_warpgroup/head_dim_size); token_idx < size<0>(sV_int64); token_idx += (128/head_dim_size)) {
+        sV_int64(token_idx, idx_in_warpgroup%head_dim_size) = 0;
+    }
+}
+
+
+template
+CUTLASS_DEVICE void fp8_transpose_v(typename T::InputT* sK_ptr,
+                                    typename T::InputT* sVt_ptr,
+                                    int tile_id)
+{
+    // every tile: (64, 64)
+    using Fp8Trans = SmemTransposeFp8_64x64;
+    Fp8Trans trans;
+    Tensor src = cute::as_position_independent_swizzle_tensor(
+        make_tensor(make_smem_ptr(sK_ptr),
+                    typename Fp8Trans::SmemLayoutTransposeV{}));
+
+    Tensor dst = cute::as_position_independent_swizzle_tensor(
+        make_tensor(make_smem_ptr(sVt_ptr),
+                    typename Fp8Trans::SmemLayoutTransposeVt{}));
+    trans.transpose(
+        flatten(src(_, _0{}, tile_id)),
+        flatten(dst(_, _0{}, tile_id)));
+}
+
+
+template
+CUTLASS_DEVICE void permute_Cregs_128_to_64(Fragment &frag) {
+    // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits
+    static_assert(decltype(size<0, 0>(frag))::value == 2);
+    static_assert(decltype(size<0, 1>(frag))::value == 2);
+    static_assert(decltype(size<0, 2>(frag))::value % 2 == 0);
+    static_assert(decltype(stride<0, 0>(frag))::value == 1);
+    static_assert(sizeof(typename Fragment::value_type) == 4);
+    Tensor frag_64b = group_modes<1, 3>(recast(frag)); // ((1, 2, N / 8), (MMA_M, MMA_N))
+    #pragma unroll
+    for (int mi = 0; mi < size<1>(frag_64b); ++mi) {
+        #pragma unroll
+        for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) {
+            cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi));
+        }
+    }
+}
+
+
+template
+CUTLASS_DEVICE void convert_type_out(Tensor const &tensor, Tensor &out) {
+    // Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong.
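+    // Vectorized dtype conversion: both tensors are recast into FragmentSize-element
+    // fragments and converted chunk-by-chunk with cutlass::NumericArrayConverter,
+    // writing into the caller-allocated `out` tensor.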
+ using From_type = typename Engine::value_type; + using To_type = typename EngineOut::value_type; + static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type)); + static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly"); + Tensor frag = recast const>(tensor); + Tensor out_frg = recast>(out); + static_assert(size(frag) == size(out_frg)); + cutlass::NumericArrayConverter convert_op; + #pragma unroll + for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); } +} + + #define CHECK_CUDA(call) \ do { \ cudaError_t status_ = call; \ diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 9c669ba..4619ce2 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -35,6 +35,8 @@ def flash_mla_with_kvcache_sm90( num_splits: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: @@ -64,6 +66,8 @@ def flash_mla_with_kvcache_sm90( causal, tile_scheduler_metadata, num_splits, + descale_q, + descale_k, ) return out, softmax_lse @@ -315,13 +319,15 @@ def flash_mla_with_kvcache( num_splits: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: capability = torch.cuda.get_device_capability(q.device.index) if capability == (9, 0): return flash_mla_with_kvcache_sm90( q, k_cache, block_table, cache_seqlens, head_dim_v, tile_scheduler_metadata, num_splits, - softmax_scale, causal, + softmax_scale, causal, descale_q, descale_k, ) elif capability == (10, 0): raise ValueError(f"Unsupported device capability: {capability}") diff --git a/setup.py b/setup.py index 58cf7b2..1b5a9bf 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,7 @@ def get_features_args(): "csrc/sm90/kernels/get_mla_metadata.cu", "csrc/sm90/kernels/mla_combine.cu", "csrc/sm90/kernels/splitkv_mla.cu", + "csrc/sm90/kernels/splitkv_mla_fp8.cu", ], extra_compile_args={ "cxx": cxx_args + get_features_args(), diff --git a/tests/test_flash_mla_sm90.py b/tests/test_flash_mla_sm90.py index 67c9d93..3556702 100644 --- a/tests/test_flash_mla_sm90.py +++ b/tests/test_flash_mla_sm90.py @@ -28,21 +28,24 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): return attn_weight @ value, lse -def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: +def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -> None: x, y = x.double(), y.double() RMSE = ((x - y) * (x - y)).mean().sqrt().item() cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) amax_diff = (x - y).abs().max().item() # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") - assert cos_diff < 1e-5 + if use_fp8: + assert cos_diff < 3e-2 + else: + assert cos_diff < 1e-5 @torch.inference_mode() -def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): +def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtype): print( - f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}" + f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}" ) - + use_fp8 = torch_dtype == torch.float8_e4m3fn cache_seqlens = 
torch.full((b,), mean_sk, dtype=torch.int32) if varlen: for i in range(b): @@ -69,6 +72,27 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): cache_seqlens, s_q * h_q // h_kv, h_kv ) + init_dtype = q.dtype + def prepare_fp8_input(): + q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = None, None, None, None, None + if use_fp8: + nonlocal q, blocked_k, blocked_v + fp8_dtype = torch.float8_e4m3fn + # descale_q = torch.ones((1), dtype=torch.float32, device="cpu") // cpu scalar faster ,but need change sglang api used + # descale_k = torch.ones((1), dtype=torch.float32, device="cpu") // cpu scalar faster ,but need change sglang api used + descale_q = torch.ones((1), dtype=torch.float32) + descale_k = torch.ones((1), dtype=torch.float32) + q_fp8 = q.to(fp8_dtype) + blocked_k_fp8 = blocked_k.to(fp8_dtype) + blocked_v_fp8 = blocked_v.to(fp8_dtype) + return q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k + + q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input() + if use_fp8: + q = q_fp8 + blocked_k = blocked_k_fp8 + blocked_v = blocked_v_fp8 + def flash_mla(): return flash_mla_with_kvcache( q, @@ -79,18 +103,23 @@ def flash_mla(): tile_scheduler_metadata, num_splits, causal=causal, + descale_q=descale_q, + descale_k=descale_k, ) def ref_mla(): + q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q + blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_k + blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_v out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) lse = torch.empty(b, h_q, s_q, dtype=torch.float32) for i in range(b): begin = i * max_seqlen_pad end = begin + cache_seqlens[i] O, LSE = scaled_dot_product_attention( - q[i].transpose(0, 1), - blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), - blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + q_[i].transpose(0, 1), + blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1), h_q=h_q, h_kv=h_kv, is_causal=causal, @@ -101,7 +130,7 @@ def ref_mla(): out_flash, lse_flash = flash_mla() out_torch, lse_torch = ref_mla() - cal_diff(out_flash, out_torch, "out") + cal_diff(out_flash, out_torch, "out", use_fp8) cal_diff(lse_flash, lse_torch, "lse") t = triton.testing.do_bench(flash_mla) @@ -116,7 +145,8 @@ def ref_mla(): def main(torch_dtype): device = torch.device("cuda:0") - torch.set_default_dtype(torch_dtype) + init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype + torch.set_default_dtype(init_dtype) torch.set_default_device(device) torch.cuda.set_device(device) torch.manual_seed(0) @@ -131,7 +161,7 @@ def main(torch_dtype): for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1 for s_q in [1, 2]: # MTP = 1, 2 for varlen in [False, True]: - test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen) + test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, torch_dtype) if __name__ == "__main__": @@ -139,9 +169,9 @@ def main(torch_dtype): parser.add_argument( "--dtype", type=str, - choices=["bf16", "fp16"], + choices=["bf16", "fp16", "float8_e4m3fn"], default="bf16", - help="Data type to use for testing (bf16 or fp16)", + help="Data type to use for testing (bf16 or fp16 or float8_e4m3fn)", ) args = parser.parse_args() @@ -149,5 +179,7 @@ def main(torch_dtype): torch_dtype = torch.bfloat16 if args.dtype == "fp16": torch_dtype = torch.float16 + elif args.dtype == "float8_e4m3fn": + 
torch_dtype = torch.float8_e4m3fn
 main(torch_dtype)
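
For reference, below is a minimal sketch of driving the new FP8 path end to end, mirroring the test above. It assumes the flash_mla package also exports a get_mla_metadata helper with the call signature used in the test (only flash_mla_with_kvcache is touched by this diff), and it uses unit descale factors like the test does; real callers would pass the scales that were used when quantizing q and the KV cache.

import torch
from flash_mla import get_mla_metadata, flash_mla_with_kvcache  # get_mla_metadata assumed

torch.set_default_device("cuda")
b, s_q, h_q, h_kv, d, dv = 4, 1, 128, 1, 576, 512
block_size, s_k = 64, 4096

cache_seqlens = torch.full((b,), s_k, dtype=torch.int32)
max_blocks = (s_k + block_size - 1) // block_size
block_table = torch.arange(b * max_blocks, dtype=torch.int32).view(b, max_blocks)

q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16)
k_cache = torch.randn(b * max_blocks, block_size, h_kv, d, dtype=torch.bfloat16)

# Quantize to fp8 e4m3; unit descales here, i.e. no rescaling inside the kernel.
q_fp8 = q.to(torch.float8_e4m3fn)
k_cache_fp8 = k_cache.to(torch.float8_e4m3fn)
descale_q = torch.ones(1, dtype=torch.float32)  # float32, shape (1,), on the GPU
descale_k = torch.ones(1, dtype=torch.float32)

tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

out, lse = flash_mla_with_kvcache(
    q_fp8, k_cache_fp8, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits,
    causal=True,
    descale_q=descale_q, descale_k=descale_k,
)
# out comes back as bfloat16 on the fp8 path; lse is float32.

descale_q / descale_k are the dequantization factors applied to the fp8 q and KV-cache values inside the kernel, which is why the reference path in the test multiplies them back onto the fp8 tensors before comparing.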