#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

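+ // New helper in this PR: validates that an optional dequantization scale is a
+ // 1-element CPU float32 tensor and returns its value as a host-side float.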
+ inline float get_scalar_f32_cpu_only(c10::optional<const at::Tensor> &scale,
+                                      const char *name = "dequant scale") {
+     TORCH_CHECK(scale.has_value(),
+                 name, " is None (optional has no value)");
+     const at::Tensor &t = *scale;
+     TORCH_CHECK(!t.device().is_cuda(),
+                 "descale_q / descale_k must be on CPU, but got ",
+                 t.device().type(), " device");
+     TORCH_CHECK(t.scalar_type() == torch::kFloat32,
+                 "descale_q / descale_k must be float32, but got ",
+                 t.scalar_type());
+     TORCH_CHECK(t.numel() == 1,
+                 "descale_q / descale_k must be a scalar, but got ",
+                 t.numel(), " elements");
+     return t.item<float>();
+ }
+
std::vector<at::Tensor>
get_mla_metadata(
        at::Tensor &seqlens_k,
@@ -68,16 +85,19 @@ mha_fwd_kvcache_mla(
        const float softmax_scale,
        bool is_causal,
        const at::Tensor &tile_scheduler_metadata,   // num_sm_parts x TileSchedulerMetaDataSize
-       const at::Tensor &num_splits                 // batch_size + 1
+       const at::Tensor &num_splits,                // batch_size + 1
+       c10::optional<const at::Tensor> &descale_q,  // batch_size
+       c10::optional<const at::Tensor> &descale_k   // batch_size
) {
    // Check the architecture
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
    TORCH_CHECK(is_sm90);

    // Check data types
-   auto q_dtype = q.dtype();
-   TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);
+   auto q_dtype = q.scalar_type();
+   TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf ||
+               q_dtype == torch::kFloat8_e4m3fn, "Unsupported dtype for query tensor");
    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
    TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
@@ -106,7 +126,7 @@ mha_fwd_kvcache_mla(
    const int num_heads_q = sizes[2];
    const int head_size_k = sizes[3];
    TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported");
-   TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported");
+   TORCH_CHECK(head_size_v == 512, "Only head_size_v == 512 is supported");

    const int max_num_blocks_per_seq = block_table.size(1);
    const int num_blocks = kcache.size(0);
@@ -133,7 +153,9 @@ mha_fwd_kvcache_mla(
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();
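+   // For FP8 (e4m3) queries the output is allocated in bf16, matching the
+   // <cutlass::float_e4m3_t, cutlass::bfloat16_t> kernel instantiation below;
+   // other dtypes keep the query's dtype.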
-   at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts);
+   auto out_type = (q_dtype == torch::kFloat8_e4m3fn) ? torch::kBFloat16 : q_dtype;
+   at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts.dtype(out_type));
+
    at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
    CHECK_CONTIGUOUS(softmax_lse);
@@ -152,6 +174,11 @@ mha_fwd_kvcache_mla(
    params.d_v = head_size_v;
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
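+   // FP8 path: read the dequantization scales into the kernel params as
+   // host-side floats. Note that while the signature comments describe
+   // descale_q / descale_k as batch_size-sized, get_scalar_f32_cpu_only
+   // enforces a single element, i.e. one per-tensor scale.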
+   if (q_dtype == torch::kFloat8_e4m3fn) {
+       params.descale_q = get_scalar_f32_cpu_only(descale_q);
+       params.descale_k = get_scalar_f32_cpu_only(descale_k);
+   }
+
    // Set the pointers and strides.
    params.q_ptr = q.data_ptr();
    params.k_ptr = kcache.data_ptr();
@@ -188,15 +215,19 @@ mha_fwd_kvcache_mla(
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    TORCH_CHECK(head_size_k == 576);
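+   // Kernel dispatch: in this change only the FP8 (e4m3) path launches a
+   // kernel; the bf16 / fp16 launches are disabled below.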
    if (q_dtype == torch::kBFloat16) {
-       run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params, stream);
-       run_flash_mla_combine_kernel<cutlass::bfloat16_t>(params, stream);
+       TORCH_CHECK(false, "Unsupported tensor dtype for query");
+       // run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params, stream);
+       // run_flash_mla_combine_kernel<cutlass::bfloat16_t>(params, stream);
    } else if (q_dtype == torch::kHalf) {
+       TORCH_CHECK(false, "Unsupported tensor dtype for query");
#ifdef FLASH_MLA_DISABLE_FP16
        TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.");
#else
-       run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
-       run_flash_mla_combine_kernel<cutlass::half_t>(params, stream);
+       // run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
+       // run_flash_mla_combine_kernel<cutlass::half_t>(params, stream);
#endif
+   } else if (q_dtype == torch::kFloat8_e4m3fn) {
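+       // The FP8 kernel is instantiated with bf16 as its second template
+       // argument (presumably the output element type), matching the bf16
+       // `out` allocation above; no separate combine kernel is launched here.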
+       run_flash_splitkv_mla_kernel<cutlass::float_e4m3_t, cutlass::bfloat16_t>(params, stream);
    } else {
        TORCH_CHECK(false, "Unsupported tensor dtype for query");
    }