@@ -11,6 +11,7 @@
 #include "core/registration.h"

 #include "cutlass/cutlass.h"
+#include <limits>

 #include "cute/tensor.hpp"
 #include "cutlass/gemm/collective/collective_builder.hpp"
@@ -169,6 +170,11 @@ struct W4A8GemmKernel {
   int k = A.size(1);
   int n = B.size(1);

+  // safely cast group_size to int
+  TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits<int>::max(),
+              "group_size out of supported range for int: ", group_size);
+  int const group_size_int = static_cast<int>(group_size);
+
   // Allocate output
   const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
   auto device = A.device();
@@ -192,7 +198,7 @@ struct W4A8GemmKernel {
       cute::tile_to_shape(LayoutAtomQuant{}, shape_B);

   // strides
-  int const scale_k = cutlass::ceil_div(k, group_size);
+  int const scale_k = cutlass::ceil_div(k, group_size_int);
   StrideA stride_A =
       cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
   // Reverse stride here due to swap and transpose
@@ -211,8 +217,8 @@ struct W4A8GemmKernel {
   using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments;

   MainloopArguments mainloop_arguments{
-      B_ptr, layout_B_reordered, A_ptr, stride_A,
-      S_ptr, stride_S, group_size};
+      B_ptr, layout_B_reordered, A_ptr, stride_A,
+      S_ptr, stride_S, group_size_int};

   EpilogueArguments epilogue_arguments{
       ChTokScalesEpilogue::prepare_args(channel_scales, token_scales),
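
For reference, the checked-narrowing pattern this patch applies can be exercised on its own outside of PyTorch. The sketch below is illustrative only: checked_narrow_to_int is a hypothetical helper, and a plain std::runtime_error stands in for TORCH_CHECK, which is what the patch actually uses.

#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>

// Hypothetical helper mirroring the checked narrowing done inline above:
// reject values that do not fit in int before the static_cast, so a large
// int64_t cannot silently wrap.
int checked_narrow_to_int(int64_t value, const char* name) {
  if (value <= 0 || value > std::numeric_limits<int>::max()) {
    // The patch reports this via TORCH_CHECK; an exception stands in here.
    throw std::runtime_error(std::string(name) +
                             " out of supported range for int: " +
                             std::to_string(value));
  }
  return static_cast<int>(value);
}

int main() {
  int64_t group_size = 128;  // typical quantization group width
  int group_size_int = checked_narrow_to_int(group_size, "group_size");
  // group_size_int can now be handed to int-taking APIs such as
  // cutlass::ceil_div(k, group_size_int) without risk of truncation.
  return group_size_int == 128 ? 0 : 1;
}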