@@ -113,6 +113,92 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
  }
}

+// TODO(simon): this is temporarily adapted from
+// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7
+// We did this to unblock Deepseek V3, but there should be a better
+// implementation to manage shared memory.
+template <typename scalar_t>
+__global__ void moe_align_block_size_global_mem_kernel(
+    scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
+    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
+    int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) {
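+  // Note: unlike the shared-memory kernel above, the caller passes
+  // `tokens_cnts` ((blockDim.x + 1) x num_experts counters) and `cumsum`
+  // (num_experts + 1 entries) as pre-allocated global-memory scratch buffers.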
+  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
+  const size_t start_idx = threadIdx.x * tokens_per_thread;
+
+  for (int i = 0; i < num_experts; ++i) {
+    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
+  }
+
+  /**
+   * In the first step we compute tokens_cnts[thread_index + 1][expert_index],
+   * which counts how many tokens in the token shard of thread_index are
+   * assigned to expert expert_index.
+   */
+  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+    ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
+  }
+
+  __syncthreads();
+
+  // For each expert we accumulate the token counts from the different threads.
+  if (threadIdx.x < num_experts) {
+    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
+    for (int i = 1; i <= blockDim.x; ++i) {
+      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
+          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
+    }
+  }
+
+  __syncthreads();
+
+  // We accumulate the token counts of all experts in thread 0.
+  if (threadIdx.x == 0) {
+    cumsum[0] = 0;
+    for (int i = 1; i <= num_experts; ++i) {
+      cumsum[i] = cumsum[i - 1] +
+                  CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
+                          block_size) *
+                      block_size;
+    }
+    *total_tokens_post_pad = cumsum[num_experts];
+  }
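+  // cumsum[e] is now the padded start offset of expert e's tokens in
+  // sorted_token_ids, and cumsum[e + 1] - cumsum[e] is a multiple of
+  // block_size.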
+
+  __syncthreads();
+
+  /**
+   * For each expert, each thread processes the tokens of the corresponding
+   * blocks and stores the corresponding expert_id for each block.
+   */
+  if (threadIdx.x < num_experts) {
+    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
+         i += block_size) {
+      expert_ids[i / block_size] = threadIdx.x;
+    }
+  }
+
+  /**
+   * Each thread processes a token shard, calculating the index of each token
+   * after sorting by expert number. Given the example topk_ids =
+   * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
+   * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
+   * padding value (preset in Python).
+   */
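+  /**
+   * For that example, cumsum = [0, 4, 8, 12, 16, 20], *total_tokens_post_pad
+   * is 20, and expert_ids = [0, 1, 2, 3, 4] (one block per expert).
+   */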
+  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+    int32_t expert_id = topk_ids[i];
+    /** cumsum[expert_id] stores the starting index of the tokens that the
+     * expert with expert_id needs to process, and
+     * tokens_cnts[threadIdx.x][expert_id] counts how many tokens assigned to
+     * expert_id come from the shards of lower-indexed threads; incrementing
+     * it as this thread emits its own tokens yields each token's rank within
+     * the expert's segment.
+     */
+    int32_t rank_post_pad =
+        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
+        cumsum[expert_id];
+    sorted_token_ids[rank_post_pad] = i;
+    ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
+  }
+}
+
template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
    scalar_t* __restrict__ out,  // [..., d]
@@ -137,25 +223,61 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_INTEGRAL_TYPES(
-      topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
-        // tensors
-        const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
-        const int32_t shared_mem =
-            ((num_thread + 1) * num_experts + (num_experts + 1)) *
-            sizeof(int32_t);
-
-        // set dynamic shared mem
-        auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
-        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
-            (void*)kernel, shared_mem));
-        kernel<<<1, num_thread, shared_mem, stream>>>(
-            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
-            experts_ids.data_ptr<int32_t>(),
-            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
-            topk_ids.numel());
-      });
+
+  // If we have a very large number of experts, we can no longer use shared
+  // memory.
+  // TODO(simon): the right solution should be calculating the exact amount of
+  // shared memory needed and using that. The num_experts >= 256 check is just
+  // a temporary solution to unblock Deepseek V3.
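+  // For example, at num_experts = 256 the shared-memory kernel would request
+  // ((256 + 1) * 256 + 257) * 4 bytes (~258 KiB) of dynamic shared memory,
+  // which is more than a single SM offers on current GPUs (e.g. 228 KB on
+  // H100, 164 KB on A100).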
+  if (num_experts >= 256) {
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
+          // calc needed amount of global mem for `tokens_cnts` and `cumsum`
+          // tensors
+          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+
+          const int32_t mem_tokens_cnts =
+              ((num_experts + 1) * num_experts) * sizeof(int32_t);
+          const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
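+          // Here num_thread == num_experts (num_experts >= 256 > WARP_SIZE),
+          // so num_experts + 1 rows matches the blockDim.x + 1 rows the
+          // kernel indexes.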
+          // allocate global memory
+          int32_t* tokens_cnts;
+          int32_t* cumsum;
+          cudaMalloc(&tokens_cnts, mem_tokens_cnts);
+          cudaMalloc(&cumsum, mem_cumsum);
+
+          auto kernel =
+              vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
+          kernel<<<1, num_thread, 0, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(),
+              experts_ids.data_ptr<int32_t>(),
+              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel(), tokens_cnts, cumsum);
+          cudaFree(tokens_cnts);
+          cudaFree(cumsum);
+        });
+  } else {
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
+          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
+          // tensors
+          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+          const int32_t shared_mem =
+              ((num_thread + 1) * num_experts + (num_experts + 1)) *
+              sizeof(int32_t);
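+          // layout: (num_thread + 1) rows of num_experts token counters,
+          // followed by num_experts + 1 cumsum entries.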
+
+          // set dynamic shared mem
+          auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
+          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
+              (void*)kernel, shared_mem));
+          kernel<<<1, num_thread, shared_mem, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(),
+              experts_ids.data_ptr<int32_t>(),
+              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel());
+        });
+  }
}

void moe_sum(torch::Tensor& input,  // [num_tokens, topk, hidden_size]