
Commit 9613ea7

Commit message: 0811

1 parent 9edee0c, commit 9613ea7

File tree

11 files changed: +924 −249 lines


csrc/cutlass

Submodule cutlass deleted from e94e888

csrc/cutlass

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
//home/wangzhe/deepseek/tune_block_shape/DeepGEMM/third-party/cutlass

csrc/flash_api.cpp

Lines changed: 40 additions & 9 deletions
@@ -20,6 +20,23 @@
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

+inline float get_scalar_f32_cpu_only(c10::optional<const at::Tensor> & scale,
+                                     const char* name = "dequant scale") {
+    TORCH_CHECK(scale.has_value(),
+                name, " is None (optional has no value)");
+    const at::Tensor& t = *scale;
+    TORCH_CHECK(!t.device().is_cuda(),
+                "descale_q / descale_k must be on CPU, but got ",
+                t.device().type(), " device");
+    TORCH_CHECK(t.scalar_type() == torch::kFloat32,
+                "descale_q / descale_k must be float32, but got ",
+                t.scalar_type());
+    TORCH_CHECK(t.numel() == 1,
+                "descale_q / descale_k must be a scalar, but got ",
+                t.numel(), " elements");
+    return t.item<float>();
+}
+
 std::vector<at::Tensor>
 get_mla_metadata(
     at::Tensor &seqlens_k,
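The helper added in the hunk above only accepts a one-element float32 tensor that lives on the CPU. A minimal caller-side sketch (illustrative only: the value, the variable name, and the main() wrapper are not part of this commit, and it assumes get_scalar_f32_cpu_only from the hunk above is in scope):

```cpp
#include <torch/torch.h>

int main() {
    // A valid descale argument: exactly one float32 element, resident on the CPU.
    c10::optional<const at::Tensor> descale_q =
        torch::full({1}, 0.0123f, torch::dtype(torch::kFloat32));   // defaults to the CPU device
    float dq = get_scalar_f32_cpu_only(descale_q, "descale_q");     // returns 0.0123f
    (void)dq;

    // Each of these variants would trip one of the TORCH_CHECKs instead:
    //   torch::full({1}, 0.0123f).cuda();                  // wrong device
    //   torch::full({1}, 1, torch::dtype(torch::kInt32));  // wrong dtype
    //   torch::full({2}, 0.0123f);                         // not a scalar
    return 0;
}
```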
@@ -68,16 +85,19 @@ mha_fwd_kvcache_mla(
     const float softmax_scale,
     bool is_causal,
     const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
-    const at::Tensor &num_splits // batch_size + 1
+    const at::Tensor &num_splits, // batch_size + 1
+    c10::optional<const at::Tensor> &descale_q, // batch_size
+    c10::optional<const at::Tensor> &descale_k // batch_size
 ) {
     // Check the architecture
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
     TORCH_CHECK(is_sm90);

     // Check data types
-    auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);
+    auto q_dtype = q.scalar_type();
+    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf ||
+                q_dtype == torch::kFloat8_e4m3fn, "Unsupported dtype for query tensor");
     TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
     TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
@@ -106,7 +126,7 @@ mha_fwd_kvcache_mla(
     const int num_heads_q = sizes[2];
     const int head_size_k = sizes[3];
     TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported");
-    TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported");
+    TORCH_CHECK(head_size_v == 512, "Only head_size_v == 512 is supported");

     const int max_num_blocks_per_seq = block_table.size(1);
     const int num_blocks = kcache.size(0);
@@ -133,7 +153,9 @@
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};

     auto opts = q.options();
-    at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts);
+    auto out_type = (q_dtype == torch::kFloat8_e4m3fn) ? torch::kBFloat16 : q_dtype;
+    at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts.dtype(out_type));
+
     at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
     CHECK_CONTIGUOUS(softmax_lse);

@@ -152,6 +174,11 @@
     params.d_v = head_size_v;
     params.scale_softmax = softmax_scale;
     params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
+    if (q_dtype == torch::kFloat8_e4m3fn) {
+        params.descale_q = get_scalar_f32_cpu_only(descale_q);
+        params.descale_k = get_scalar_f32_cpu_only(descale_k);
+    }
+
     // Set the pointers and strides.
     params.q_ptr = q.data_ptr();
     params.k_ptr = kcache.data_ptr();
@@ -188,15 +215,19 @@
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     TORCH_CHECK(head_size_k == 576);
     if (q_dtype == torch::kBFloat16) {
-        run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params, stream);
-        run_flash_mla_combine_kernel<cutlass::bfloat16_t>(params, stream);
+        TORCH_CHECK(false, "Unsupported tensor dtype for query");
+        //run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params, stream);
+        //run_flash_mla_combine_kernel<cutlass::bfloat16_t>(params, stream);
     } else if (q_dtype == torch::kHalf) {
+        TORCH_CHECK(false, "Unsupported tensor dtype for query");
 #ifdef FLASH_MLA_DISABLE_FP16
         TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.");
 #else
-        run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
-        run_flash_mla_combine_kernel<cutlass::half_t>(params, stream);
+        //run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
+        //run_flash_mla_combine_kernel<cutlass::half_t>(params, stream);
 #endif
+    } else if (q_dtype == torch::kFloat8_e4m3fn) {
+        run_flash_splitkv_mla_kernel<cutlass::float_e4m3_t, cutlass::bfloat16_t>(params, stream);
     } else {
         TORCH_CHECK(false, "Unsupported tensor dtype for query");
     }
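With the dispatch above, only the FP8 path is active in this commit: BF16 and FP16 queries now hit TORCH_CHECK(false), while FP8-e4m3 queries run run_flash_splitkv_mla_kernel<cutlass::float_e4m3_t, cutlass::bfloat16_t> with the output allocated in BF16 (the opts.dtype(out_type) change earlier). A standalone sketch of just that output-dtype rule (placeholder shape and main() wrapper added for illustration; not part of this commit):

```cpp
#include <torch/torch.h>

int main() {
    // FP8-e4m3 queries get a BF16 output buffer; other supported dtypes keep the query dtype.
    at::ScalarType q_dtype  = torch::kFloat8_e4m3fn;
    at::ScalarType out_type = (q_dtype == torch::kFloat8_e4m3fn) ? torch::kBFloat16 : q_dtype;
    at::Tensor out = torch::empty({1, 64, 128, 512}, torch::dtype(out_type));  // placeholder shape
    TORCH_CHECK(out.scalar_type() == torch::kBFloat16);
    return 0;
}
```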

csrc/kernels/fp8_transpose_v.h

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
/**
 * ref to Fa3's SmemTranspose64x64:
 * https://github.com/Dao-AILab/flash-attention/blob/0823cf7b5d96499c1c79a4f64b1e256a035ba4b4/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L26
 */

#pragma once
using namespace cute;

template <int kBlockN, int kHeadDim>
struct SmemTransposeFp8_64x64 {
  static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0));

  using Element = cutlass::float_e4m3_t;
  using TransposeShapeAtomV = Shape<_64, _64>;
  using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
  using SmemLayoutV =
      decltype(tile_to_shape(SmemLayoutAtomV{},
                             Shape<Int<kBlockN>, Int<kHeadDim>>{}));

  // for fp8 in-kernel transpose -- src layout
  using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
  using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
  using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{}, shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{})));
  using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));

  // For fp8, this is the memory transpose.
  using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
  using SmemLayoutVt =
      decltype(tile_to_shape(SmemLayoutAtomVt{},
                             Shape<Int<kHeadDim>, Int<kBlockN>>{}));

  // for fp8 in-kernel transpose -- dst layout
  using SmemLayoutVtTrans = decltype(composition(
      SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{})));
  using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
  using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
  using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{})));
  using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));

  using ldsm_thread_shape = Shape<_4, _1, _8, _4>;
  using ldsm_value_shape = Shape<_2, _8, _2, _1>;
  using ldsm_value_stride = Stride<_2, _4, _1, _0>;
  using TiledCopyLDSM = decltype(make_tiled_copy(Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, Layout<ldsm_thread_shape>{},
                                                 Layout<ldsm_value_shape, ldsm_value_stride>{}));
  TiledCopyLDSM tiled_copy_ldsm;

  using stsm_thread_shape = Shape<_4, _1, _8, _4>;
  // using stsm_thread_stride = Stride<_1, _0, _4, _32>;
  using stsm_value_shape = Shape<_4, _4, _2, _1>;
  using stsm_value_stride = Stride<_1, _8, _4, _0>;

  using TiledCopySTSM = decltype(make_tiled_copy(Copy_Atom<SM90_U32x4_STSM_N, Element>{}, Layout<stsm_thread_shape>{},
                                                 Layout<stsm_value_shape, stsm_value_stride>{}));
  TiledCopySTSM tiled_copy_stsm;

  template <class SmemTensor, class SmemTensorOut>
  CUTLASS_DEVICE void transpose(SmemTensor &&s_in, SmemTensorOut &&s_out) {
    using namespace cute;

    auto tid = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
    auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid);
    auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid);

    auto tXsX = thr_copy_ldsm.partition_S(s_in);
    auto tXrX = make_tensor<Element>(shape(tXsX));
    auto tXsX_out = thr_copy_stsm.partition_D(s_out);

    cute::copy(tiled_copy_ldsm, tXsX, tXrX);

    auto data = tXrX.data();
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < size(tXrX); n += 8) {
      uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
      auto upper = data_32bit[0];
      auto lower = data_32bit[1];
      data_32bit[0] = __byte_perm(upper, lower, 0x6420);
      data_32bit[1] = __byte_perm(upper, lower, 0x7531);
    }

    cute::copy(tiled_copy_stsm, tXrX, tXsX_out);
  }
};
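The inner loop above repacks eight FP8 values at a time with __byte_perm: selector 0x6420 gathers the even-indexed bytes of two 32-bit words and 0x7531 gathers the odd-indexed bytes, which regroups the 8-bit values to match the 16-bit granularity of the ldmatrix/stmatrix copies. A host-side sketch of just that permutation (the byte_perm stand-in below is written from CUDA's documented __byte_perm behavior and is not code from this commit):

```cpp
#include <cstdint>
#include <cstdio>

// Host-side stand-in for CUDA's __byte_perm(x, y, s): bytes 0-3 of the 8-byte pool
// come from x, bytes 4-7 from y; each selector nibble picks one source byte.
static uint32_t byte_perm(uint32_t x, uint32_t y, uint32_t s) {
    uint64_t pool = (static_cast<uint64_t>(y) << 32) | x;
    uint32_t out = 0;
    for (int i = 0; i < 4; ++i) {
        uint32_t sel = (s >> (4 * i)) & 0x7;
        out |= static_cast<uint32_t>((pool >> (8 * sel)) & 0xFFu) << (8 * i);
    }
    return out;
}

int main() {
    // Two 32-bit words, each packing four FP8 bytes: {a0,a1,a2,a3} and {b0,b1,b2,b3}.
    uint32_t upper = 0x03020100u;  // a0..a3 in little-endian byte order
    uint32_t lower = 0x13121110u;  // b0..b3
    printf("%08x\n", byte_perm(upper, lower, 0x6420));  // 12100200 -> {a0, a2, b0, b2}
    printf("%08x\n", byte_perm(upper, lower, 0x7531));  // 13110301 -> {a1, a3, b1, b3}
    return 0;
}
```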

csrc/kernels/params.h

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ struct Flash_fwd_mla_params {
     int q_head_per_hk;  // The number of q_head(s) per KV head, = h_q / h_k
     bool is_causal;
     float scale_softmax, scale_softmax_log2;
+    float descale_q, descale_k;

     void *__restrict__ q_ptr;
     void *__restrict__ k_ptr;

0 commit comments
