diff --git a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
index fdac47c425d6..bdaf2acb4136 100644
--- a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
+++ b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
@@ -11,6 +11,7 @@
 #include "core/registration.h"
 
 #include "cutlass/cutlass.h"
+#include <limits>
 
 #include "cute/tensor.hpp"
 #include "cutlass/gemm/collective/collective_builder.hpp"
@@ -169,6 +170,11 @@ struct W4A8GemmKernel {
     int k = A.size(1);
     int n = B.size(1);
 
+    // safely cast group_size to int
+    TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits<int>::max(),
+                "group_size out of supported range for int: ", group_size);
+    int const group_size_int = static_cast<int>(group_size);
+
     // Allocate output
     const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
     auto device = A.device();
@@ -192,7 +198,7 @@ struct W4A8GemmKernel {
         cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
 
     // strides
-    int const scale_k = cutlass::ceil_div(k, group_size);
+    int const scale_k = cutlass::ceil_div(k, group_size_int);
     StrideA stride_A =
         cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
     // Reverse stride here due to swap and transpose
@@ -211,8 +217,8 @@
     using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments;
 
     MainloopArguments mainloop_arguments{
-        B_ptr, layout_B_reordered, A_ptr, stride_A,
-        S_ptr, stride_S, group_size};
+        B_ptr, layout_B_reordered, A_ptr, stride_A,
+        S_ptr, stride_S, group_size_int};
 
     EpilogueArguments epilogue_arguments{
         ChTokScalesEpilogue::prepare_args(channel_scales, token_scales),