forked from microsoft/BitBLAS

BitBLAS is a library to support mixed-precision matrix multiplications, especially for quantized LLM deployment.




To use the outdated llvm-config:

git clone
cd ncurses5-compat-libs/
gpg --recv-keys CC2AF4472167BE03
makepkg -sir
conda create -n bitblas python=3.9
conda activate bitblas
conda install gcc_linux-64 gxx_linux-64
conda install cuda -c nvidia/label/cuda-12.1
pip install --pre torch torchvision torchaudio --index-url
conda install cmake
python build

In my case, I had to fix the build either by patching TVM and re-running CMake:

cd 3rdparty/tvm
git apply my_patch_to_tvm.patch
cd ./build
cmake ..

or by editing CMakeFiles/tvm.dir/link.txt and inserting -lcuda next to every -lxml2.

To use the TVM Python interface, install these extra dependencies:

pip install decorator psutil attrs thefuzz pytest tqdm

To test:

cd build/lib
python -c "import bitblas; print(bitblas.__version__)"

Hello World

  1. Make changes
  2. python build
  3. python

To skip the C++ build, comment out build_tvm(llvm_path) in BitBLASBuilPydCommand of

Code Structure

tvm is imported from BitBLAS/build/lib/bitblas/3rdparty/tvm/python. bitblas-related modules are imported from BitBLAS/python/bitblas.


graph TD;
   bitblas_bitnet_example[<a href="">bitblas_bitnet_example</a>]
   bitblas_matmul_init[<a href="">BitLinear.bitblas_matmul = Matmul of Operator parent</a>]
   bitblas_matmul[<a href="">BitLinear.bitblas_matmul.forward</a>]
   transform_weight_call[<a href="">BitLinear.bitblas_matmul.transform_weight call in post_process_weights</a>]
   transform_weight[<a href="">BitLinear.bitblas_matmul.transform_weight</a>]
   general_compress[<a href="">general_compress</a>]
   bitblas_bitnet_example --> bitblas_matmul_init
   bitblas_matmul_init --> bitblas_matmul
   bitblas_matmul_init --> transform_weight_call --> transform_weight --> general_compress

   matmul_init[<a href="">Matmul.__init__</a>]
   matmul_forward[<a href="">Matmul.forward</a>]
   _forward_from_prebuild_lib[<a href="">Operator._forward_from_prebuild_lib</a>]
   operator_libcall[<a href=""> of forward values</a>]
   _build_default_module_call[<a href="">Matmul._build_default_module call</a>]
   _build_default_module[<a href="">Matmul._build_default_module</a>]
   _build_runtime_module[<a href="">Operator._build_runtime_module</a>]
   operator_lib_init_call[<a href="">Operator.lib = self.wrapper.load_lib call</a>]
   tvm_build_call[<a href=""> call of self.optimized_func</a>]
   bitblas_matmul_init --> matmul_init
   matmul_init --> _build_default_module_call --> _build_default_module --> _build_runtime_module --> GPU --> operator_lib_init_call --> operator_libcall
   bitblas_matmul --> matmul_forward --> _forward_from_prebuild_lib --> operator_libcall
   _build_runtime_module --> CPU --> tvm_build_call
   tvm_build_call --> Operator.rt_mod
   tvm_build_call --> Operator.function_handle
   tvm_build_call --> Operator.torch_func

   apply_default_schedule_call[<a href="">Matmul.optimized_func = apply_default_schedule of self.prim_func_mod</a>]
   apply_default_schedule[<a href="">Operator.apply_default_schedule</a>]
   ApplyDefaultSchedule[<a href="">ApplyDefaultSchedule</a>]
   _select_implementation_call[<a href="">Operator.prim_func_mod = self._select_implementation</a>]
   _select_implementation[<a href="">Matmul._select_implementation</a>]
   weight_dequantize_implementation[<a href="">weight_dequantize_implementation imported from select_implementation</a>]
   matmul_nt_dequantize_b[<a href="">matmul_nt_dequantize_b</a>]
   construct_tvm_graph[<a href="">te.compute call</a>]

   _build_default_module --> apply_default_schedule_call --> apply_default_schedule --> ApplyDefaultSchedule -->|wrapped| module_pass
   matmul_init --> _select_implementation_call --> _select_implementation --> weight_dequantize_implementation --> matmul_nt_dequantize_b --> construct_tvm_graph

For module_pass:

graph TD;
   module_pass[<a href="">module_pass</a>]
   _wrap_class_module_pass[<a href="">_wrap_class_module_pass</a> wraps a PyModulePass class]
   __init_handle_by_constructor__call[<a href="">__init_handle_by_constructor__ call</a>]
   ModulePass[<a href="">ModulePass</a>]
   register_object[<a href="">tvm._ffi.register_object</a>]
   _register_object[<a href="">_register_object</a>]
   tvm_runtime_obj[<a href="">tvm.runtime.Object</a>]
   tvm_runtime_obj_base[<a href="">tvm.runtime.ObjectBase</a>]
   __init_handle_by_constructor__[<a href="">__init_handle_by_constructor__</a>]
   handle[Object.handle = constructor call]
   MakeModulePass[<a href="">_ffi_transform_api.MakeModulePass</a>]
   _ffi_transform_api[<a href="">tvm.transform initialization</a>]
   _init_api[<a href="">_init_api</a>]
   get_global_func[<a href="">get_global_func</a>]
   _LIB_init[<a href="">_LIB</a>]

   module_pass -->|return| _wrap_class_module_pass --> pass_cls
   _wrap_class_module_pass --> __init_handle_by_constructor__call --> __init_handle_by_constructor__
   __init_handle_by_constructor__call -->|operand| MakeModulePass
   __init_handle_by_constructor__call -->|operand| _pass_func --> pass_cls.transform_module
   _ffi_transform_api --> _init_api --> _init_api_prefix --> get_global_func -->
   _LIB_init -->
   _init_api_prefix -->|define| MakeModulePass
   _wrap_class_module_pass -->|inherited| ModulePass -->|inherited| Pass -->|inherited| tvm_runtime_obj -->|inherited| tvm_runtime_obj_base --> __init_handle_by_constructor__
   register_object --> _register_object 
   ModulePass -->|wrapped| register_object
   ModulePass -->|register| transform.ModulePass
   Pass -->|wrapped| register_object
   Pass -->|register| transform.Pass
   __init_handle_by_constructor__ --> handle
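As a rough illustration of the wrapping chain in this graph, here is a toy, pure-Python analogy of _wrap_class_module_pass: the returned subclass instantiates the user's pass class, closes over its transform_module in _pass_func, and hands that callback to __init_handle_by_constructor__. Everything below is a hypothetical stand-in (make_module_pass for the FFI constructor MakeModulePass, a simplified ModulePass, and the example pass AppendTag); real TVM builds a C++ object and passes a PassContext as ctx.

```python
# Toy analogy of TVM's class-to-pass wrapping. NOT TVM code: all names are
# hypothetical stand-ins that only mimic the call structure in the graph.

def make_module_pass(pass_func, pass_info):
    """Hypothetical stand-in for MakeModulePass: bundle callback and info."""
    return {"pass_func": pass_func, "info": pass_info}


class ModulePass:
    """Hypothetical stand-in for tvm.transform.ModulePass."""

    def __init_handle_by_constructor__(self, constructor, *args):
        # In TVM this calls into C++ and stores the resulting object handle.
        self.handle = constructor(*args)

    def __call__(self, mod):
        # Apply the wrapped transformation (ctx is omitted in this toy).
        return self.handle["pass_func"](mod, None)


def wrap_class_module_pass(pass_cls, pass_info):
    """Return a ModulePass subclass that instantiates pass_cls when constructed."""

    class PyModulePass(ModulePass):
        def __init__(self, *args, **kwargs):
            self.handle = None
            inst = pass_cls(*args, **kwargs)

            # Capture `inst`, not `self`, so the callback creates no reference cycle.
            def _pass_func(mod, ctx):
                return inst.transform_module(mod, ctx)

            self.__init_handle_by_constructor__(make_module_pass, _pass_func, pass_info)
            self._inst = inst

    PyModulePass.__name__ = pass_cls.__name__
    return PyModulePass


class AppendTag:
    """A user-defined pass; here a "module" is just a list of strings."""

    def __init__(self, tag):
        self.tag = tag

    def transform_module(self, mod, ctx):
        return mod + [self.tag]


AppendTagPass = wrap_class_module_pass(AppendTag, {"opt_level": 0, "name": "AppendTag"})
p = AppendTagPass("scheduled")
print(p(["main"]))  # ['main', 'scheduled']
```

The point of the closure is the same as in the graph: the pass object only stores a function handle, and the user's transform_module is reached through it when the pass runs.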

Important functions:

  • post_process_weights calls weight_quant on the weights and then applies transform_weight.
  • weight_quant scales the weights down using their mean value and clamps them to [-1, 1], yielding ternary weights (a ternary net).
  • transform_weight compresses an integer matrix into a compact matrix of W_dtype.
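A minimal sketch of what a ternary weight_quant can look like, assuming BitNet-style "absmean" scaling (divide by the mean absolute weight, round, then clamp to [-1, 1]). This is an illustration in plain Python on nested lists, not the code from the BitBLAS example, which operates on torch tensors.

```python
# Sketch of ternary weight quantization, ASSUMING "absmean" scaling:
# scale = 1 / mean(|w|), q = clamp(round(w * scale), -1, 1).

def weight_quant(weight, eps=1e-5):
    """Quantize a 2-D list of floats to ternary values {-1, 0, 1}.

    Returns (q, scale) such that weight is approximately q / scale.
    """
    flat = [abs(w) for row in weight for w in row]
    mean_abs = max(sum(flat) / len(flat), eps)  # avoid division by zero
    scale = 1.0 / mean_abs
    q = [[max(-1, min(1, round(w * scale))) for w in row] for row in weight]
    return q, scale


W = [[0.9, -0.05, -1.7],
     [0.2, 0.6, -0.4]]
q, s = weight_quant(W)
print(q)  # [[1, 0, -1], [0, 1, -1]]
```

The resulting integer matrix is what a compression step like transform_weight would then pack into a narrower storage type.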

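What is re-scaling? Below is extracted example code that computes the rescaled weights for grouped quantization: every group of group_size columns shares one scaling factor and one zero point per row.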

import torch

group_size = 128
in_features = 1024
input_shape = (1, in_features)
weight_shape = (1024, in_features)  # (out_features, in_features)
scaling_shape = (1024, in_features // group_size)
zeros_shape = (1024, in_features // group_size)
output_shape = (1, 1024)

weight_tensor = torch.randint(0, 16, weight_shape, dtype=torch.int8).cuda()  # uint4 codes
scaling = torch.rand(scaling_shape, dtype=torch.float16).cuda()
zeros = torch.rand(zeros_shape, dtype=torch.float16).cuda()
rescaling_tensor = torch.empty(weight_shape, dtype=torch.float16).cuda()

# Compute reference result with manual scaling and zero-point adjustment
# rescale = (weight - zeros) * scaling
for i in range(in_features // group_size):  # group number i in range(8)
    for j in range(group_size):  # j-th element (column) of group i
        # within each group, we use the same zeros and scaling factors.
        rescaling_tensor[:, i*group_size+j] = (weight_tensor[:, i*group_size+j] - zeros[:, i]) * scaling[:, i]
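The loop above applies (weight - zeros) * scaling per group. The debugging prim_func under Useful tools, whose dequantize_info has zeros_mode "rescale", computes B * Scale - Zeros instead. The two conventions describe the same weights when the rescaled zero point equals zeros * scaling, which this plain-Python sketch checks on one row (dequant_original and dequant_rescale are hypothetical helper names).

```python
# Group-wise dequantization of one weight row under two conventions:
#   (weight - zero) * scale      (as in the loop above)
#   weight * scale - zero'       (zeros_mode "rescale" in the debug prim_func)
# They agree when zero' = zero * scale.

def dequant_original(q_row, scales, zeros, group_size):
    return [(q - zeros[k // group_size]) * scales[k // group_size]
            for k, q in enumerate(q_row)]


def dequant_rescale(q_row, scales, zeros, group_size):
    return [q * scales[k // group_size] - zeros[k // group_size]
            for k, q in enumerate(q_row)]


group_size = 4
q_row = [0, 3, 15, 7, 1, 2, 8, 4]   # uint4 codes, two groups of four
scales = [0.5, 0.25]                # one scale per group
zeros = [2.0, 4.0]                  # one zero point per group

w_orig = dequant_original(q_row, scales, zeros, group_size)
zeros_rescaled = [z * s for z, s in zip(zeros, scales)]
w_resc = dequant_rescale(q_row, scales, zeros_rescaled, group_size)
print(w_orig)  # [-1.0, 0.5, 6.5, 2.5, -0.75, -0.5, 1.0, 0.0]
```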

For decoding, below is the prim_func generated for A_dtype="float16" activations and W_dtype="uint4" weights with group_size=128 (one Scale and one Zeros entry per 128 columns):

# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(
      A: T.Buffer((1, 1024), "float16"),
      B: T.Buffer((1024, 512), "int8"), Scale: T.Buffer((1024, 8), "float16"),
      Zeros: T.Buffer((1024, 8), "float16"),
      D: T.Buffer((1, 1024), "float16")):
        # with T.block("root"):
        B_decode = T.alloc_buffer((1024, 1024), "float16")
        C = T.alloc_buffer((1, 1024), "float16")
        for n, k in T.grid(1024, 1024):
            with T.block("B_decode"):
                v_n, v_k = T.axis.remap("SS", [n, k]) # “S” (for spatial), “R” (for reduction)
                T.reads(B[v_n, v_k // 2], Zeros[v_n, v_k // 128], Scale[v_n, v_k // 128])
                T.writes(B_decode[v_n, v_k])
                B_decode[v_n, v_k] = (  # decompressing B
                    T.Cast("float16", T.bitwise_and(
                        T.shift_right(B[v_n, v_k // 2], T.Cast("int8", v_k % 2 * 4)),
                        T.int8(15)))  # 0b1111: keep the 4-bit code
                    - Zeros[v_n, v_k // 128]  # re-centering
                ) * Scale[v_n, v_k // 128]  # scaling
        for i, j, k in T.grid(1, 1024, 1024):
            with T.block("C"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(A[v_i, v_k], B_decode[v_j, v_k])
                T.writes(C[v_i, v_j])
                with T.init():
                    C[v_i, v_j] = T.float16(0)
                C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B_decode[v_j, v_k] # matrix multiplication
        for i, j in T.grid(1, 1024):
            with T.block("D"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(C[v_i, v_j])
                T.writes(D[v_i, v_j])
                D[v_i, v_j] = C[v_i, v_j]
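The B_decode block reads element k of a row from byte k // 2, shifts it right by (k % 2) * 4 and masks it with 15, so two uint4 codes share one int8 byte with the even index in the low nibble. A plain-Python sketch of that layout follows; pack_uint4_row is an assumption that only mirrors the decode expression, and whether transform_weight produces exactly these bytes (for example with fast_decoding enabled) is not shown here.

```python
# uint4 storage layout implied by the B_decode block: element k of a row
# lives in byte k // 2 (low nibble for even k, high nibble for odd k).
# Bytes are kept as signed int8 values, like the "int8" buffer B.

def to_int8(x):
    """Reinterpret an unsigned byte 0..255 as a signed int8 value."""
    return x - 256 if x >= 128 else x


def pack_uint4_row(codes):
    """Pack a row of uint4 codes (0..15, even length) into int8 bytes."""
    assert len(codes) % 2 == 0 and all(0 <= c <= 15 for c in codes)
    return [to_int8(codes[k] | (codes[k + 1] << 4)) for k in range(0, len(codes), 2)]


def decode_uint4(packed_row, k):
    """Mirror of T.bitwise_and(T.shift_right(B[n, k // 2], k % 2 * 4), 15)."""
    # Python's >> on negative ints is an arithmetic shift, like on int8 in TIR;
    # the & 15 mask discards the sign-extended bits either way.
    return (packed_row[k // 2] >> (k % 2 * 4)) & 15


codes = [1, 15, 0, 9, 7, 8]
packed = pack_uint4_row(codes)
print(packed)  # [-15, -112, -121]
print([decode_uint4(packed, k) for k in range(len(codes))])  # [1, 15, 0, 9, 7, 8]
```

This is also why B has shape (1024, 512) while B_decode has shape (1024, 1024).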




Code Example

Without re-scaling:

import sys
import os
path = os.path.join("/home/tk/Desktop/bitblas/BitBLAS", "./build/lib")
sys.path.insert(0, path)
import torch
import bitblas

class TvmLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, W_dtype="uint4"):
        super().__init__()  # set up torch module internals
        matmul_config = bitblas.MatmulConfig(
            M=1,  # rows of A (batch size of the example input below)
            N=out_features,  # output features
            K=in_features,  # input features
            A_dtype="float16",  # activation A dtype
            W_dtype=W_dtype,  # weight W dtype
            accum_dtype="float16",  # accumulation dtype
            out_dtype="float16",  # output dtype
            layout="nt", # A is non-transpose and W is transpose
            group_size=-1, # 128,  # setting for grouped quantization
            with_scaling=False,  # setting for scaling factor
            with_zeros=False,  # setting for zeros
            zeros_mode="original",  # setting for how to calculate zeros
        )
        self.matmul = bitblas.Matmul(config=matmul_config)
        # random integer weights in the uint4 range, shape (out_features, in_features)
        init_W = torch.randint(0, 7, (out_features, in_features), dtype=torch.int8).cuda()
        self.set_weight(init_W)

    def set_weight(self, origin_int_W):
        self.W = self.matmul.transform_weight(origin_int_W)
        self.W_ori = origin_int_W

    def forward(self, A):
        output = self.matmul(A, self.W)
        verify = A @ self.W_ori.T.half()
        assert torch.allclose(output, verify, atol=1e-2)
        return output

inp = torch.rand((1, 8), dtype=torch.float16).cuda()
new_module = TvmLinear(8, 5)
out = new_module.forward(inp)
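The check in forward compares against A @ self.W_ori.T because layout="nt" keeps W in (out_features, in_features) order and uses it transposed. A dependency-free reference of that contraction (matmul_nt is a hypothetical helper) is handy for checking shapes and values on small inputs without a GPU.

```python
# Reference for the "nt" layout: A is (M, K) non-transposed, W is stored as
# (N, K) and used transposed, so out[i][j] = sum_k A[i][k] * W[j][k].

def matmul_nt(A, W):
    K = len(A[0])
    assert all(len(row) == K for row in W), "W must be (N, K) with the same K as A"
    return [[sum(a * w for a, w in zip(a_row, w_row)) for w_row in W] for a_row in A]


A = [[1.0, 2.0, 3.0]]           # M=1, K=3
W = [[1, 0, 2], [0, 1, 1]]      # N=2, K=3 integer weights (like init_W)
print(matmul_nt(A, W))          # [[7.0, 5.0]]
```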

With re-scaling:

class TvmLinear(torch.nn.Module):
    def __init__(self, batch_size, in_features, out_features,
        W_dtype="uint4", group_size=-1, debug=False, no_extra_mem=False, tuning=True):
        super().__init__() # set up torch module (e.g., _backward_hooks)

        self.group_size = group_size if group_size != -1 else in_features
        matmul_config = bitblas.MatmulConfig(
            M=batch_size,  # batch size; the application example passes a list such as [2, seq_len]
            N=out_features,  # output features
            K=in_features,  # input features
            A_dtype="float16",  # activation A dtype
            W_dtype=W_dtype,  # weight W dtype
            accum_dtype="float16",  # accumulation dtype
            out_dtype="float16",  # output dtype
            group_size=self.group_size,  # one scaling/zeros pair per group of K
            with_scaling=True,  # setting for scaling factor
            with_zeros=True,  # setting for zeros
            zeros_mode="rescale",  # setting for how to calculate zeros
            fast_decoding=False # important! avoid post-processing (i.e., LOP3Permutate)
        )
        self.matmul = bitblas.Matmul(config=matmul_config, enable_tuning=tuning)

        # set random initial (binary) weights
        #init_W = torch.randint(0, 2, (out_features, in_features), dtype=torch.int8)

        self.bits = int(''.join(list(filter(str.isdigit, W_dtype))))
        self.scaling = None
        self.zeros = None
        self.weight = None # placeholder for T5 modeling access.
        self.debug = debug
        self.no_extra_mem = no_extra_mem

    def _set_quantized_weight(self, W_quant):
        self.W_store = self.matmul.transform_weight(W_quant)
        self.W_store = self.W_store.cuda()

    def set_weight(self, W):
        out_features, in_features = W.shape
        assert in_features % self.group_size == 0

        reshape = (out_features, -1, self.group_size)
        W_reshape = W.reshape(*reshape)
        group_max = W_reshape.max(-1).values
        group_min = W_reshape.min(-1).values
        group_max = group_max.unsqueeze(-1).expand(reshape)
        group_min = group_min.unsqueeze(-1).expand(reshape)
        #Q_min = -(2**(self.bits - 1))
        #Q_max = 2**(self.bits - 1) - 1
        Q_min = 0
        Q_max = 2**(self.bits) - 1
        ratio = (Q_max - Q_min) / (group_max - group_min)

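        W_quant = ((W_reshape - group_min) * ratio).round() + Q_min
        W_quant = W_quant.clamp(Q_min, Q_max).to(torch.int8)  # integer codes for packing
        W_quant = W_quant.reshape(W.shape)
        self._set_quantized_weight(W_quant)  # pack and store as self.W_store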

        if not self.no_extra_mem:
            self.scaling = 1 / ratio[:,:,0].clone().cuda()
            self.zeros = Q_min * self.scaling - group_min[:,:,0].clone().cuda()

        if self.debug:
            self.W_origin = W
            self.W_quant = W_quant
            self.debug_scaling = self.scaling.unsqueeze(-1).expand(reshape).reshape(W.shape)
            self.debug_zeros = self.zeros.unsqueeze(-1).expand(reshape).reshape(W.shape)
            self.debug_W = W_quant.float() * self.debug_scaling - self.debug_zeros

    def forward(self, A):
        output = self.matmul(A, self.W_store, scale=self.scaling, zeros=self.zeros)
        if self.debug:
            output_check = A @ self.W_origin.T
            print(torch.allclose(self.debug_W.half(), self.W_origin, atol=1e-1))
            print(torch.allclose(A @ self.debug_W.half().T, A @ self.W_origin.T, atol=1e-1))
        return output

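inp = torch.rand((4, 1, 8), dtype=torch.float16).cuda()
M = TvmLinear(4, 8, 5, W_dtype="uint4", group_size=4, debug=True)

W = torch.rand((5, 8), dtype=torch.float16).cuda()
M.set_weight(W)  # quantize, pack and store W before calling forward

out = M.forward(inp)
print('inp', inp)
print('W', W)
print('out', out)

# dequantized weights, the reference product and the kernel output should agree
assert torch.allclose(M.debug_W.half(), W, atol=1e-1)
assert torch.allclose(
    inp @ M.debug_W.T.half(),
    inp @ W.T, atol=1e-1)
assert torch.allclose(
    inp @ W.T,
    out, atol=1e-1)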
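The per-group min-max quantization in set_weight, together with the "rescale" reconstruction used for debug_W (W_quant * scaling - zeros), can be sketched in plain Python. quantize_group and dequantize_group are hypothetical helpers; like set_weight, they assume each group has max > min, and the reconstruction error per element is at most half of that group's scale.

```python
# Per-group min-max quantization mirroring set_weight, using the
# "rescale" convention: w ~= q * scale - zero.

def quantize_group(ws, bits=4):
    q_min, q_max = 0, 2 ** bits - 1
    lo, hi = min(ws), max(ws)
    ratio = (q_max - q_min) / (hi - lo)        # assumes hi > lo
    q = [round((w - lo) * ratio) + q_min for w in ws]
    scale = 1 / ratio
    zero = q_min * scale - lo                  # so that q * scale - zero ~= w
    return q, scale, zero


def dequantize_group(q, scale, zero):
    return [c * scale - zero for c in q]


row = [0.10, 0.85, 0.40, 0.55, 0.05, 0.95, 0.30, 0.70]
group_size = 4
codes, restored = [], []
for g in range(0, len(row), group_size):
    q, scale, zero = quantize_group(row[g:g + group_size])
    codes.append(q)
    restored.extend(dequantize_group(q, scale, zero))

print(codes)  # [[0, 15, 6, 9], [0, 15, 4, 11]]
print(max(abs(a - b) for a, b in zip(row, restored)))  # bounded by scale / 2 per group
```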

Application example:

import gc
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

# TvmLinear (the re-scaling version above) works on float16 CUDA tensors
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small",
    torch_dtype=torch.float16).cuda()

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
inputs = processor(
    text=["80s pop track with bassy drums and synth"],
    padding=True, return_tensors="pt",
).to("cuda")
seq_len = inputs.input_ids.shape[-1]

linear_classes = (torch.nn.Linear, )
to_replace_lst = []
for key, module in model.named_modules():
    if isinstance(module, linear_classes):
        if module.bias is not None: continue
        old_W = module.weight.detach().cpu()
        path_fields = key.split('.')
        parent_key = '.'.join(path_fields[:-1])
        child_key = path_fields[-1]
        to_replace_lst.append((parent_key, child_key, old_W))

n_replaced = 0
for parent_key, child_key, old_W in to_replace_lst:
    parent = model.get_submodule(parent_key)
    delattr(parent, child_key)
    out_features, in_features = old_W.shape
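    M = TvmLinear([2, seq_len], in_features, out_features,
        group_size=-1, W_dtype="uint4", no_extra_mem=False, tuning=False)
    M.set_weight(old_W.cuda())  # quantize and pack the original fp16 weights
    # e.g. a 768x768 layer: 589824 * 2 bytes => 294912 * 1 (packed uint4) + 768 * 2 (scale) + 768 * 2 (zeros)
    # save: 881,664 bytes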
    setattr(parent, child_key, M)
    n_replaced += 1
    print(n_replaced, parent_key, child_key, old_W.shape)
    if n_replaced % 20 == 0:
        print('GC', gc.collect())


print('generating ...')
with torch.no_grad():
    audio_values = model.generate(**inputs, max_new_tokens=256)

import scipy.io.wavfile
sampling_rate = model.config.audio_encoder.sampling_rate
data = audio_values[0, 0].detach().cpu().float().numpy()
scipy.io.wavfile.write("output.wav", rate=sampling_rate, data=data)

To play the result:

mplayer -ao openal output.wav
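The byte counts in the comment inside the replacement loop (589824 fp16 weights becoming 294912 packed bytes plus 768 fp16 scales and 768 fp16 zeros, saving 881,664 bytes) can be reproduced with a small helper. quantized_bytes and fp16_bytes are hypothetical; they count only weight, scale and zeros storage, with group_size=-1 meaning one group per output row as in TvmLinear, and ignore any padding or layout changes BitBLAS may apply.

```python
# Memory accounting for replacing one fp16 Linear weight (out x in) with
# packed low-bit codes plus fp16 scale and zeros (one value per row and group).

def quantized_bytes(out_features, in_features, bits=4, group_size=-1):
    if group_size == -1:
        group_size = in_features          # one group spanning the whole row
    n_groups = in_features // group_size
    packed = out_features * in_features * bits // 8
    scale = out_features * n_groups * 2   # fp16 scaling factors
    zeros = out_features * n_groups * 2   # fp16 zero points
    return packed + scale + zeros


def fp16_bytes(out_features, in_features):
    return out_features * in_features * 2


before = fp16_bytes(768, 768)
after = quantized_bytes(768, 768, bits=4, group_size=-1)
print(before, after, before - after)  # 1179648 297984 881664
```

Smaller groups add more scale and zeros entries, so the savings shrink slightly as group_size decreases.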

Useful tools

To debug a TIR function:

import sys
import os
path = os.path.join("/home/tk/Desktop/bitblas/BitBLAS", "./build/lib")
path2 = os.path.join("/home/tk/Desktop/bitblas/BitBLAS", "./build/lib/bitblas/3rdparty/tvm/python")
sys.path.insert(0, path)
sys.path.insert(0, path2)

import tvm
from tvm.ir import IRModule
from tvm.script import tir as T
import numpy as np

@T.prim_func
def main(var_A: T.handle, B: T.Buffer((768, 384), "int8"), Scale: T.Buffer((768, 3), "float16"), Zeros: T.Buffer((768, 3), "float16"), var_D: T.handle):
    T.func_attr({"dequantize_info": {"B_decode": {"decode_block": "B_decode", "fast_decoding": T.bool(False), "group_size": 256, "source_format": {"bits": 4, "format": "uint"}, "storage_dtype": "int8", "target_format": "float16", "with_scaling": T.bool(True), "with_zeros": T.bool(True), "zeros_mode": "rescale"}}, "opt_shapes": {"m": [2, 12]}, "tir.noalias": T.bool(True)})
    m = T.int32()
    A = T.match_buffer(var_A, (m, 768), "float16")
    D = T.match_buffer(var_D, (m, 768), "float16")
    # with T.block("root"):
    B_decode = T.alloc_buffer((768, 768), "float16")
    C = T.alloc_buffer((m, 768), "float16")
    for n, k in T.grid(768, 768):
        with T.block("B_decode"):
            v_n, v_k = T.axis.remap("SS", [n, k])
            T.reads(B[v_n, v_k // 2], Scale[v_n, v_k // 256], Zeros[v_n, v_k // 256])
            T.writes(B_decode[v_n, v_k])
            B_decode[v_n, v_k] = T.Cast("float16", T.bitwise_and(T.shift_right(B[v_n, v_k // 2], T.Cast("int8", v_k % 2 * 4)), T.int8(15))) * Scale[v_n, v_k // 256] - Zeros[v_n, v_k // 256]
    for i, j, k in T.grid(m, 768, 768):
        with T.block("C"):
            v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
            T.reads(A[v_i, v_k], B_decode[v_j, v_k])
            T.writes(C[v_i, v_j])
            with T.init():
                C[v_i, v_j] = T.float16(0)
            C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B_decode[v_j, v_k]
    for i, j in T.grid(m, 768):
        with T.block("D"):
            v_i, v_j = T.axis.remap("SS", [i, j])
            T.reads(C[v_i, v_j])
            T.writes(D[v_i, v_j])
            D[v_i, v_j] = C[v_i, v_j]

rt_mod = tvm.build(main, target='llvm')
func = rt_mod[rt_mod.entry_name]

A = tvm.nd.array(np.ones((2, 768), dtype="float16"))
B = tvm.nd.array(np.random.randint(0, 2, size=(768, 384), dtype="int8"))
Scale = tvm.nd.array(np.ones((768, 3), dtype="float16"))
Zeros = tvm.nd.array(np.ones((768, 3), dtype="float16"))
D = tvm.nd.array(np.ones((2, 768), dtype="float16"))
func(A, B, Scale, Zeros, D)
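Since func writes its result into D, an independent reference makes the output easy to check. The plain-Python sketch below (reference_D is a hypothetical helper) mirrors the B_decode, C and D blocks of the prim_func on tiny inputs shaped like the script's: with A, Scale and Zeros all ones and B bytes in {0, 1}, each output equals the number of 1-bytes in that row of B minus K.

```python
# Reference for the debug prim_func (zeros_mode "rescale"):
#   B_decode[n, k] = uint4(B[n, k // 2], k) * Scale[n, k // G] - Zeros[n, k // G]
#   D[i, n]        = sum_k A[i][k] * B_decode[n, k]

def decode_nibble(byte, k):
    return (byte >> (k % 2 * 4)) & 15


def reference_D(A, B, Scale, Zeros, group_size):
    K = len(A[0])
    B_decode = [[decode_nibble(B[n][k // 2], k) * Scale[n][k // group_size]
                 - Zeros[n][k // group_size] for k in range(K)]
                for n in range(len(B))]
    return [[sum(a * w for a, w in zip(a_row, w_row)) for w_row in B_decode]
            for a_row in A]


# Tiny analogue of the script's inputs: A, Scale, Zeros all ones, B bytes in {0, 1}.
K, N, M, G = 8, 3, 2, 4
A = [[1.0] * K for _ in range(M)]
B = [[1, 0, 1, 1], [0, 0, 0, 0], [1, 1, 1, 1]]   # (N, K // 2) packed bytes
Scale = [[1.0] * (K // G) for _ in range(N)]
Zeros = [[1.0] * (K // G) for _ in range(N)]
print(reference_D(A, B, Scale, Zeros, G))  # [[-5.0, -8.0, -4.0], [-5.0, -8.0, -4.0]]
```

By the same reasoning, for the script's own inputs every D[i, n] should equal the number of ones in B[n] minus 768, which can be compared with D.numpy().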

