Mismatch between Bitblas result and torch.matmul in QuickStart.md with batch size > 1 #56
Comments
@MekkCyber The code works fine on my env:
import bitblas
import torch
# enabling debug output
bitblas.set_log_level("Debug")
matmul_config = bitblas.MatmulConfig(
M=4, # M dimension
N=1024, # N dimension
K=1024, # K dimension
A_dtype="float16", # activation A dtype
W_dtype="int4", # weight W dtype
accum_dtype="float16", # accumulation dtype
out_dtype="float16", # output dtype
layout="nt", # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
with_bias=False, # bias
# configs for weight only quantization
group_size=None, # setting for grouped quantization
with_scaling=False, # setting for scaling factor
with_zeros=False, # setting for zeros
zeros_mode=None, # setting for how to calculate zeros
)
matmul = bitblas.Matmul(config=matmul_config)
# Create input matrices
input_tensor = torch.rand((4, 1024), dtype=torch.float16).cuda()
weight_tensor = torch.randint(0, 7, (1024, 1024), dtype=torch.int8).cuda()
# Transform weight tensor to int4 data type
weight_tensor_int4 = matmul.transform_weight(weight_tensor)
# Perform mixed-precision matrix multiplication
output_tensor = matmul(input_tensor, weight_tensor_int4)
# Reference result using PyTorch matmul for comparison
ref_result = torch.matmul(input_tensor, weight_tensor.t().to(torch.float16))
# Assert that the results are close within a specified tolerance, note that the int4 randint value is a little bigger than the float16 value, so we set the atol to 1.0
print("Ref output:", ref_result)
print("BitBLAS output:", output_tensor)
torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0)
It seems you may have forgotten to set M in matmul_config to match the input tensor's batch size. We should consider adding a static shape check in future releases.
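For readers hitting the same mismatch, a static shape check of the kind mentioned above could look roughly like the sketch below. `check_static_shape` is a hypothetical helper, not part of the BitBLAS API; it assumes the "nt" layout from the config, where the activation is (M, K) and the unpacked weight is (N, K).

```python
def check_static_shape(input_shape, weight_shape, M, N, K):
    """Raise ValueError if tensor shapes do not match a static (M, N, K) config.

    Hypothetical helper (not a BitBLAS function). Assumes layout "nt":
    the activation A is (M, K) and the unpacked weight W is (N, K).
    """
    if tuple(input_shape) != (M, K):
        raise ValueError(
            f"input shape {tuple(input_shape)} does not match (M, K) = {(M, K)}"
        )
    if tuple(weight_shape) != (N, K):
        raise ValueError(
            f"weight shape {tuple(weight_shape)} does not match (N, K) = {(N, K)}"
        )


# A config built with M=4 accepts a (4, 1024) activation.
check_static_shape((4, 1024), (1024, 1024), M=4, N=1024, K=1024)

# A config built with M=1 rejects the same activation.
try:
    check_static_shape((4, 1024), (1024, 1024), M=1, N=1024, K=1024)
except ValueError as err:
    print("shape check failed:", err)
```

Running a check like this before calling the kernel turns a silent numerical mismatch into an explicit error.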
Yeah, sorry, I didn't set M to 4 in matmul_config! Thanks for your help, @LeiWang1999.
@LeiWang1999 I would like to ask a follow-up question: if we are running language models with a dynamic batch or sequence size, how should we configure the BitBLAS matmul?
@w32zhong, set M to a list, for example [1, 16, 32, 64]. BitBLAS will generate dynamic kernels for each M and dispatch a shape like 31 to the kernel with M=32.
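To illustrate the dispatch rule described above (31 goes to the M=32 kernel), here is a minimal pure-Python sketch. `select_kernel_m` is a hypothetical function and not BitBLAS's actual dispatch code. Falling back to the largest entry for oversized inputs is an assumption of this sketch, not something stated in the thread.

```python
from bisect import bisect_left


def select_kernel_m(m, configured_ms):
    """Pick the smallest configured M that is >= the runtime m.

    Hypothetical illustration only. Assumption: if m exceeds every
    configured value, fall back to the largest configured M.
    """
    ms = sorted(configured_ms)
    idx = bisect_left(ms, m)  # index of the first entry >= m
    return ms[min(idx, len(ms) - 1)]


configured = [1, 16, 32, 64]
for runtime_m in (1, 4, 31, 32, 100):
    print(runtime_m, "->", select_kernel_m(runtime_m, configured))
```

With the list from the comment above, a runtime M of 31 maps to 32 and a runtime M of 4 maps to 16.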
Thank you so much. I actually just figured it out. It looks like the dynamic-shape kernel also handles batch sizes that are not listed in M. I guess the M values passed in as a tuple are the ones that get specifically optimized during hardware tuning?
Yes. You may also want to check out the generated CUDA source in ~/.cache/bitblas/nvidia. @w32zhong
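A quick way to see what is in that cache directory is sketched below. `list_cached_files` is a hypothetical helper. The layout and file names under the directory are not described in this thread, so the sketch simply lists every file it finds.

```python
from pathlib import Path


def list_cached_files(cache_dir=Path.home() / ".cache" / "bitblas" / "nvidia"):
    """Return all files under cache_dir, sorted by path.

    Hypothetical helper. The internal layout of the cache is not
    assumed, so every regular file is returned regardless of extension.
    """
    root = Path(cache_dir)
    if not root.is_dir():
        return []  # nothing has been cached yet, or the path differs
    return sorted(p for p in root.rglob("*") if p.is_file())


for path in list_cached_files():
    print(path)
```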
Hi everyone!
When I run the first script in QuickStart.md with a batch size of 1 in the input tensor, it works fine and I get the following result:
But once I use a batch size of 4, for example, I get too many mismatches.
Is this intended behaviour? Thanks for your help.
I am using:
Python 3.10.14
bitblas 0.0.1.dev12