
Conversation

@usamec commented Aug 11, 2025

Fixes #311 by correcting a small typo in the packing script: the bit permutation had a collision (a minimal illustration of that failure mode follows).
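
To make the failure mode concrete, here is a minimal, self-contained sketch of what a collision in a packing bit permutation does. The permutation values below are invented for illustration and are not taken from the actual BitBLAS packing script:

bits = [1, 0, 1, 1, 0, 0, 0, 1]  # eight uint1 weights to be packed into one byte

good_perm = [0, 1, 2, 3, 4, 5, 6, 7]  # bijective mapping: every weight gets its own bit slot
bad_perm = [0, 1, 2, 3, 4, 5, 6, 6]   # collision: source bits 6 and 7 both map to slot 6

def pack(weights, perm):
    packed = 0
    for src, dst in enumerate(perm):
        packed |= (weights[src] & 1) << dst
    return packed

print(bin(pack(bits, good_perm)))  # 0b10001101: all eight weights are recoverable
print(bin(pack(bits, bad_perm)))   # 0b1001101: slot 7 is never written, weight 7 leaks into slot 6

With a collision the packed byte can no longer be decoded back to the original weights, which is presumably why the uint1 path produced wrong results before this fix.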

Previously, even a basic example adapted from the README for uint1 x fp16 did not work; it now passes. I include it here for context:

import bitblas
import torch

# uncomment to enable debug output
# bitblas.set_log_level("Debug")

matmul_config = bitblas.MatmulConfig(
    M=1,  # M dimension
    N=2048,  # N dimension
    K=1024,  # K dimension
    A_dtype="float16",  # activation A dtype
    W_dtype="uint1",  # weight W dtype
    accum_dtype="float16",  # accumulation dtype
    out_dtype="float16",  # output dtype
    layout="nt",  # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
    with_bias=False,  # bias
    # configs for weight only quantization
    group_size=None,  # setting for grouped quantization
    with_scaling=False,  # setting for scaling factor
    with_zeros=False,  # setting for zeros
    zeros_mode=None,  # setting for how to calculate zeros
)

matmul = bitblas.Matmul(config=matmul_config)

# Create input matrices
input_tensor = torch.rand((1, 1024), dtype=torch.float16).cuda()
weight_tensor = torch.randint(0, 2, (2048, 1024), dtype=torch.int8).cuda()

# Pack the weight tensor into the uint1 storage format
weight_tensor_packed = matmul.transform_weight(weight_tensor)

# Perform mixed-precision matrix multiplication
output_tensor = matmul(input_tensor, weight_tensor_packed)

# Reference result using PyTorch matmul for comparison
ref_result = torch.matmul(input_tensor, weight_tensor.t().to(torch.float16))
# Assert that the results are close within a specified tolerance; with K=1024 binary weights
# the accumulated values are large relative to fp16 precision, so we allow an atol of 1.0
print("Ref output:", ref_result)
print("BitBLAS output:", output_tensor)
torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0)
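
As an additional sanity check (not part of the original report), the same assertion can be repeated over a few shapes; this sketch only reuses the API calls shown above, and the shapes themselves are arbitrary:

import bitblas
import torch

# Repeat the uint1 x fp16 check for a couple of GEMM shapes (chosen arbitrarily).
for N, K in [(2048, 1024), (4096, 4096)]:
    config = bitblas.MatmulConfig(
        M=1, N=N, K=K,
        A_dtype="float16", W_dtype="uint1",
        accum_dtype="float16", out_dtype="float16",
        layout="nt", with_bias=False,
        group_size=None, with_scaling=False,
        with_zeros=False, zeros_mode=None,
    )
    mm = bitblas.Matmul(config=config)

    a = torch.rand((1, K), dtype=torch.float16).cuda()
    w = torch.randint(0, 2, (N, K), dtype=torch.int8).cuda()

    out = mm(a, mm.transform_weight(w))
    ref = torch.matmul(a, w.t().to(torch.float16))
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1.0)
    print(f"N={N}, K={K}: OK")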

@yiakwy-xpu-ml-framework-team

@usamec Good Job!

@LeiWang1999 (Contributor)

Thanks, but I have left Microsoft and no longer have permission to merge this pull request.

Successfully merging this pull request may close these issues: transform_weight broken for uint1 (#311)