Commit 4d263c0 (parent: 57abfd4)

fix format error

Signed-off-by: Jinzhen Lin <[email protected]>

File tree: 1 file changed, +18 −14 lines

csrc/moe/moe_align_sum_kernels.cu (18 additions, 14 deletions)
@@ -8,7 +8,7 @@
 #include "../cuda_compat.h"
 #include "../dispatch_utils.h"
 
-#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
+#define CEILDIV(x, y) (((x) + (y)-1) / (y))
 
 namespace vllm {
 namespace moe {
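
Only whitespace changes in this hunk: clang-format drops the space in (y)-1, and CEILDIV still computes the integer ceiling of x / y. A minimal standalone check of the identity, for reference:

    #include <cassert>

    #define CEILDIV(x, y) (((x) + (y)-1) / (y))

    int main() {
      // Integer ceiling without floating point: ceil(10/4) = 3, ceil(8/4) = 2.
      assert(CEILDIV(10, 4) == 3);
      assert(CEILDIV(8, 4) == 2);
      return 0;
    }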
@@ -33,7 +33,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
 
   extern __shared__ int32_t shared_mem[];
   int32_t* cumsum = shared_mem;  // 1d tensor with shape (num_experts + 1)
-  token_cnts_t* tokens_cnts = (token_cnts_t *) (shared_mem + blockDim.x + 1);
+  token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + blockDim.x + 1);
 
   for (int i = 0; i < num_experts; ++i) {
     tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
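
The reformatted cast carves one dynamic shared-memory allocation into two arrays: cumsum (num_experts + 1 int32 values) at the front and a (num_thread + 1) x num_experts count matrix of token_cnts_t behind it. The offset skips blockDim.x + 1 int32 slots, which is safe because blockDim.x >= num_experts. A hypothetical host-side mirror of the same carving (buffer sizes and names are illustrative only):

    #include <cstdint>
    #include <vector>

    int main() {
      using token_cnts_t = uint16_t;  // chosen by the host-side dispatch below
      const int num_experts = 8;
      const int num_thread = 32;  // plays the role of blockDim.x
      const size_t cnt_bytes =
          (size_t)(num_thread + 1) * num_experts * sizeof(token_cnts_t);
      // One flat int32 buffer: cumsum slots first, the count matrix after them.
      std::vector<int32_t> buf(num_thread + 1 + (cnt_bytes + 3) / 4);
      int32_t* cumsum = buf.data();
      token_cnts_t* tokens_cnts = (token_cnts_t*)(buf.data() + num_thread + 1);
      (void)cumsum;
      (void)tokens_cnts;
      return 0;
    }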
@@ -70,7 +70,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                  block_size) *
              block_size;
     }
-    *total_tokens_post_pad = (int32_t) cumsum[num_experts];
+    *total_tokens_post_pad = (int32_t)cumsum[num_experts];
   }
 
   __syncthreads();
@@ -222,20 +222,21 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
 
   int device_max_shared_mem;
   auto dev = topk_ids.get_device();
-  cudaDeviceGetAttribute(&device_max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+  cudaDeviceGetAttribute(&device_max_shared_mem,
+                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
 
   const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
   const int32_t shared_mem_i32 =
-      ((num_thread + 1) * num_experts + (num_experts + 1)) *
-      sizeof(int32_t);
+      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
   const int32_t shared_mem_i16 =
-      ((num_thread + 1) * num_experts) * sizeof(uint16_t) + (num_experts + 1) *
-      sizeof(int32_t);
+      ((num_thread + 1) * num_experts) * sizeof(uint16_t) +
+      (num_experts + 1) * sizeof(int32_t);
 
   bool use_global_memory = false, use_i16 = false;
   if (shared_mem_i16 > device_max_shared_mem) {
     use_global_memory = true;
-  } else if (shared_mem_i32 > device_max_shared_mem && topk_ids.numel() <= 65535) {
+  } else if (shared_mem_i32 > device_max_shared_mem &&
+             topk_ids.numel() <= 65535) {
     // when nelements of topk_ids is smaller than 65535 (max value of uint16),
     // element value of token_cnts would also smaller than 65535,
     // so we can use uint16 as dtype of token_cnts
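
The reflowed expressions are semantically unchanged: shared_mem_i32 covers a (num_thread + 1) x num_experts int32 count matrix plus the int32 cumsum, while shared_mem_i16 shrinks only the count matrix to uint16_t. A standalone restatement of the selection logic, with illustrative names (Variant, pick_variant) and WARP_SIZE assumed to be 32:

    #include <algorithm>
    #include <cstdint>

    enum class Variant { GlobalMemory, SharedI16, SharedI32 };

    Variant pick_variant(int32_t num_experts, int64_t numel,
                         int device_max_shared_mem) {
      const int32_t num_thread = std::max(num_experts, (int32_t)32);
      const int32_t shared_mem_i32 =
          ((num_thread + 1) * num_experts + (num_experts + 1)) *
          sizeof(int32_t);
      const int32_t shared_mem_i16 =
          ((num_thread + 1) * num_experts) * sizeof(uint16_t) +
          (num_experts + 1) * sizeof(int32_t);
      if (shared_mem_i16 > device_max_shared_mem)
        return Variant::GlobalMemory;  // even uint16 counts do not fit
      if (shared_mem_i32 > device_max_shared_mem && numel <= 65535)
        return Variant::SharedI16;  // every count is bounded by numel <= 65535
      return Variant::SharedI32;
    }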
@@ -249,8 +250,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
     // tensors
     const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
 
-    auto options_int =
-        torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
+    auto options_int = torch::TensorOptions()
+                           .dtype(torch::kInt)
+                           .device(topk_ids.device());
     torch::Tensor token_cnts_buffer =
         torch::empty({(num_experts + 1) * num_experts}, options_int);
     torch::Tensor cumsum_buffer =
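
torch::TensorOptions is a fluent builder, so splitting the chain across lines is purely cosmetic. A hypothetical helper showing the same pattern (make_workspace is illustrative, not part of the commit):

    #include <torch/torch.h>

    // Build the options once, then reuse them for every workspace tensor.
    torch::Tensor make_workspace(int64_t n, const torch::Device& dev) {
      auto options_int = torch::TensorOptions()
                             .dtype(torch::kInt)
                             .device(dev);
      return torch::empty({n}, options_int);
    }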
@@ -270,7 +272,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
     VLLM_DISPATCH_INTEGRAL_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
           // set dynamic shared mem
-          auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
+          auto kernel =
+              vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
           AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
               (void*)kernel, shared_mem_i16));
           kernel<<<1, num_thread, shared_mem_i16, stream>>>(
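
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize presumably wraps the CUDA opt-in for large dynamic shared memory; launches requesting more than the default 48 KiB fail without it. A minimal sketch of that opt-in using the plain runtime API (demo_kernel and launch_demo are hypothetical):

    #include <cuda_runtime.h>

    // Hypothetical stand-in for moe_align_block_size_kernel.
    __global__ void demo_kernel(int* out) {
      extern __shared__ int smem[];
      smem[threadIdx.x] = (int)threadIdx.x;
      __syncthreads();
      if (threadIdx.x == 0) *out = smem[blockDim.x - 1];
    }

    cudaError_t launch_demo(int* out, int shared_bytes, cudaStream_t stream) {
      // Opt in before launching with a large dynamic shared-memory request.
      cudaError_t err = cudaFuncSetAttribute(
          (const void*)demo_kernel,
          cudaFuncAttributeMaxDynamicSharedMemorySize, shared_bytes);
      if (err != cudaSuccess) return err;
      demo_kernel<<<1, 256, shared_bytes, stream>>>(out);
      return cudaGetLastError();
    }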
@@ -282,8 +285,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
         });
   } else {
     VLLM_DISPATCH_INTEGRAL_TYPES(
-        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-          auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
+        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
+          auto kernel =
+              vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
           AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
               (void*)kernel, shared_mem_i32));
           kernel<<<1, num_thread, shared_mem_i32, stream>>>(
