8
8
#include " ../cuda_compat.h"
9
9
#include " ../dispatch_utils.h"
10
10
11
- #define CEILDIV (x, y ) (((x) + (y) - 1 ) / (y))
11
+ #define CEILDIV (x, y ) (((x) + (y)- 1 ) / (y))
12
12
13
13
namespace vllm {
14
14
namespace moe {
@@ -33,7 +33,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
33
33
34
34
extern __shared__ int32_t shared_mem[];
35
35
int32_t * cumsum = shared_mem; // 1d tensor with shape (num_experts + 1)
36
- token_cnts_t * tokens_cnts = (token_cnts_t *) (shared_mem + blockDim .x + 1 );
36
+ token_cnts_t * tokens_cnts = (token_cnts_t *) (shared_mem + blockDim .x + 1 );
37
37
38
38
for (int i = 0 ; i < num_experts; ++i) {
39
39
tokens_cnts[index (num_experts, threadIdx .x + 1 , i)] = 0 ;
@@ -70,7 +70,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
70
70
block_size) *
71
71
block_size;
72
72
}
73
- *total_tokens_post_pad = (int32_t ) cumsum[num_experts];
73
+ *total_tokens_post_pad = (int32_t )cumsum[num_experts];
74
74
}
75
75
76
76
__syncthreads ();
@@ -222,20 +222,21 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
222
222
223
223
int device_max_shared_mem;
224
224
auto dev = topk_ids.get_device ();
225
- cudaDeviceGetAttribute (&device_max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
225
+ cudaDeviceGetAttribute (&device_max_shared_mem,
226
+ cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
226
227
227
228
const int32_t num_thread = max ((int32_t )num_experts, WARP_SIZE);
228
229
const int32_t shared_mem_i32 =
229
- ((num_thread + 1 ) * num_experts + (num_experts + 1 )) *
230
- sizeof (int32_t );
230
+ ((num_thread + 1 ) * num_experts + (num_experts + 1 )) * sizeof (int32_t );
231
231
const int32_t shared_mem_i16 =
232
- ((num_thread + 1 ) * num_experts) * sizeof (uint16_t ) + (num_experts + 1 ) *
233
- sizeof (int32_t );
232
+ ((num_thread + 1 ) * num_experts) * sizeof (uint16_t ) +
233
+ (num_experts + 1 ) * sizeof (int32_t );
234
234
235
235
bool use_global_memory = false , use_i16 = false ;
236
236
if (shared_mem_i16 > device_max_shared_mem) {
237
237
use_global_memory = true ;
238
- } else if (shared_mem_i32 > device_max_shared_mem && topk_ids.numel () <= 65535 ) {
238
+ } else if (shared_mem_i32 > device_max_shared_mem &&
239
+ topk_ids.numel () <= 65535 ) {
239
240
// when the number of elements in topk_ids is at most 65535 (the max value of uint16),
240
241
// each element value of token_cnts is also at most 65535,
241
242
// so we can use uint16 as the dtype of token_cnts
@@ -249,8 +250,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
249
250
// tensors
250
251
const int32_t num_thread = max ((int32_t )num_experts, WARP_SIZE);
251
252
252
- auto options_int =
253
- torch::TensorOptions ().dtype (torch::kInt ).device (topk_ids.device ());
253
+ auto options_int = torch::TensorOptions ()
254
+ .dtype (torch::kInt )
255
+ .device (topk_ids.device ());
254
256
torch::Tensor token_cnts_buffer =
255
257
torch::empty ({(num_experts + 1 ) * num_experts}, options_int);
256
258
torch::Tensor cumsum_buffer =
@@ -270,7 +272,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
270
272
VLLM_DISPATCH_INTEGRAL_TYPES (
271
273
topk_ids.scalar_type (), " moe_align_block_size_kernel" , [&] {
272
274
// set dynamic shared mem
273
- auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t , uint16_t >;
275
+ auto kernel =
276
+ vllm::moe::moe_align_block_size_kernel<scalar_t , uint16_t >;
274
277
AT_CUDA_CHECK (VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize (
275
278
(void *)kernel, shared_mem_i16));
276
279
kernel<<<1 , num_thread, shared_mem_i16, stream>>> (
@@ -282,8 +285,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
282
285
});
283
286
} else {
284
287
VLLM_DISPATCH_INTEGRAL_TYPES (
285
- topk_ids.scalar_type (), " moe_align_block_size_kernel" , [&] {
286
- auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t , int32_t >;
288
+ topk_ids.scalar_type (), " moe_align_block_size_kernel" , [&] {
289
+ auto kernel =
290
+ vllm::moe::moe_align_block_size_kernel<scalar_t , int32_t >;
287
291
AT_CUDA_CHECK (VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize (
288
292
(void *)kernel, shared_mem_i32));
289
293
kernel<<<1 , num_thread, shared_mem_i32, stream>>> (
0 commit comments