Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "core/registration.h"

#include "cutlass/cutlass.h"
#include <limits>

#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
Expand Down Expand Up @@ -169,6 +170,11 @@ struct W4A8GemmKernel {
int k = A.size(1);
int n = B.size(1);

// safely cast group_size to int
TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits<int>::max(),
"group_size out of supported range for int: ", group_size);
int const group_size_int = static_cast<int>(group_size);

// Allocate output
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
auto device = A.device();
Expand All @@ -192,7 +198,7 @@ struct W4A8GemmKernel {
cute::tile_to_shape(LayoutAtomQuant{}, shape_B);

// strides
int const scale_k = cutlass::ceil_div(k, group_size);
int const scale_k = cutlass::ceil_div(k, group_size_int);
StrideA stride_A =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
// Reverse stride here due to swap and transpose
Expand All @@ -211,8 +217,8 @@ struct W4A8GemmKernel {
using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments;

MainloopArguments mainloop_arguments{
B_ptr, layout_B_reordered, A_ptr, stride_A,
S_ptr, stride_S, group_size};
B_ptr, layout_B_reordered, A_ptr, stride_A,
S_ptr, stride_S, group_size_int};

EpilogueArguments epilogue_arguments{
ChTokScalesEpilogue::prepare_args(channel_scales, token_scales),
Expand Down