
vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas

logger module-attribute

logger = init_logger(__name__)

BitBLASLinearKernel

Bases: MPLinearKernel

Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py
class BitBLASLinearKernel(MPLinearKernel):

    OPT_FEATURES: list[int] = BITBLAS_OPTIMIZE_FEATURES
    ENABLE_TUNING: bool = True
    MATMUL_LAYOUT: str = "nt"
    BITBLAS_DTYPES: dict[torch.dtype, str] = {
        torch.float32: "float32",
        torch.float16: "float16",
        torch.bfloat16: "bfloat16",
        torch.half: "float16",
        torch.int8: "int8",
    }
    bitblas_matmul: object = None

    def __init__(
        self,
        c: MPLinearLayerConfig,
        w_q_param_name: str,
        w_s_param_name: str,
        w_zp_param_name: Optional[str] = None,
        w_gidx_param_name: Optional[str] = None,
        bitblas_quant_config: Optional[QuantizationConfig] = None,
    ):
        self.quant_config = bitblas_quant_config
        super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name,
                         w_gidx_param_name)

    def repack_bitblas_from_gptq(
        self,
        b_q_weight: torch.Tensor,
        scales: torch.Tensor,
        qzeros: Optional[torch.Tensor] = None,
    ):
        from bitblas.quantization.utils import general_compress
        assert self.bitblas_matmul is not None, "bitblas_matmul is None"

        quant_config = self.quant_config
        # qweight in the old GPTQ quant linear is stored as
        # (outfeatures, infeatures) and must be transposed.
        qweight = b_q_weight.T.contiguous().view(
            quant_config.torch_storage_dtype)  # type: ignore[union-attr]
        intweight = unpack_gptq_qweight(
            qweight,
            quant_config.weight_bits).contiguous()  # type: ignore[union-attr]
        if self.bitblas_matmul.weight_transform is not None:  # type: ignore[attr-defined]
            qweight = self.bitblas_matmul.weight_transform(  # type: ignore[attr-defined]
                intweight.cpu()).cuda()
        # scales in the old GPTQ quant linear are stored as
        # (infeatures // group_size, outfeatures) and must be transposed.
        scales = scales.T.contiguous()

        if qzeros is None:
            return qweight, scales, None

        # qzeros should be de-quantized to int zeros.
        weight_bits = quant_config.weight_bits  # type: ignore[union-attr]
        intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous()
        zeros: Optional[torch.Tensor] = None
        zeros_mode = self.bitblas_matmul.config.zeros_mode  # type: ignore[attr-defined]
        if zeros_mode == "original":
            zeros = intzeros.to(torch.float16).contiguous()
        elif zeros_mode == "rescale":
            assert zeros is not None, "zeros should not be None"
            zeros[:, :] = intzeros.to(torch.float16)[:, :] * scales[:, :]
        elif zeros_mode == "quantized":
            zeros = (
                torch.Tensor(
                    general_compress(
                        intzeros.T.contiguous().cpu().numpy(),
                        weight_bits,
                    ))
                .to(qweight.device)
                .to(quant_config.torch_storage_dtype)  # type: ignore[union-attr]
                .contiguous())
        else:
            raise ValueError("Unsupported zeros type: {}".format(zeros_mode))

        return qweight, scales, zeros

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:

        try:
            import bitblas
            if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
                return False, (
                    f"bitblas>={MINIMUM_BITBLAS_VERSION} is required, "
                    f"found {bitblas.__version__}. Please run "
                    f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`")
        except ImportError:
            return False, ("bitblas is not installed. Please install it "
                           "by running `pip install "
                           f"bitblas>={MINIMUM_BITBLAS_VERSION}`")

        quant_types = query_bitblas_supported_quant_types(c.zero_points)
        if c.weight_type not in quant_types:
            return False, (f"Quant type ({c.weight_type}) not supported by"
                           f"  BitBLAS, supported types are: {quant_types}")

        if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES:
            return False, (f"Group size ({c.group_size}) not supported by "
                           "BitBLAS, supported group sizes are: "
                           f"{BITBLAS_SUPPORTED_GROUP_SIZES}")

        return check_bitblas_supports_shape(
            c.partition_weight_shape[1],  # out_features
            c.partition_weight_shape[0],  # in_features
            c.full_weight_shape[0],  # in_features
            c.group_size)

    # Note: assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = getattr(layer, self.w_q_name).device
        c = self.config
        quant_config = self.quant_config

        # Default names since bitblas requires empty parameters for these,
        # TODO: remove this requirement from bitblas (allow optional tensors)
        if self.w_gidx_name is None:
            self.w_gidx_name = "g_idx"
        if self.w_zp_name is None:
            self.w_zp_name = "qzeros"

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = bitblas_sort_g_idx(
                getattr(layer, self.w_gidx_name))
            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, bitblas_make_empty_g_idx(device))
            layer.g_idx_sort_indices = bitblas_make_empty_g_idx(device)

        if c.zero_points:
            raise NotImplementedError("Zero points not supported by BitBLAS")
        else:
            setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device))

        # Repack weights
        bitblas_qweight, bitblas_scales, bitblas_qzeros = (
            self.repack_bitblas_from_gptq(
                layer.qweight,
                layer.scales,
                None if quant_config.is_sym else  # type: ignore[union-attr]
                layer.qzeros,  # type: ignore[union-attr]
            ))
        replace_parameter(layer, self.w_q_name, bitblas_qweight)
        replace_parameter(layer, self.w_s_name, bitblas_scales)
        if bitblas_qzeros is not None:
            replace_parameter(layer, self.w_zp_name, bitblas_qzeros)

    def configure_bitblas_matmul(
        self,
        infeatures: int,
        outfeatures: int,
        params_dtype: torch.dtype,
        bias: bool,
    ) -> None:
        enable_tuning = self.ENABLE_TUNING
        layout = self.MATMUL_LAYOUT
        bits = self.quant_config.weight_bits  # type: ignore[union-attr]
        self._configure_bitblas_matmul(
            infeatures,
            outfeatures,
            params_dtype,
            enable_tuning,
            bias,
            layout,
            bits,
        )

    def _configure_bitblas_matmul(
        self,
        infeatures,
        outfeatures,
        params_dtype,
        enable_tuning,
        bias,
        layout,
        bits,
    ):
        from bitblas import MatmulConfig
        bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
        quant_config = self.quant_config
        with_scaling = False
        with_zeros = False
        group_size = quant_config.group_size  # type: ignore[union-attr]
        zeros_mode = quant_config.zeros_mode  # type: ignore[union-attr]
        if quant_config.quant_method == "gptq":  # type: ignore[union-attr]
            with_scaling = True
            with_zeros = True
            W_dtype = f"uint{bits}"
            if quant_config.is_sym:  # type: ignore[union-attr]
                with_zeros = False
                W_dtype = f"int{bits}"
        else:
            raise ValueError(
                f"Unsupported quant_method {quant_config.quant_method}"  # type: ignore[union-attr]
            )

        matmul_config = MatmulConfig(
            M=self.OPT_FEATURES,
            N=outfeatures,
            K=infeatures,
            A_dtype=bitblas_dtype,
            W_dtype=W_dtype,
            out_dtype=bitblas_dtype,
            accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
            storage_dtype=quant_config.storage_dtype,  # type: ignore[union-attr]
            with_scaling=with_scaling,
            with_zeros=with_zeros,
            group_size=group_size,
            with_bias=bias,
            layout=layout,
            zeros_mode=zeros_mode,
        )
        self.bitblas_matmul = self._get_or_create_bitblas_operator(
            matmul_config, enable_tuning)

    def _get_or_create_bitblas_operator(self, config, enable_tuning):
        from bitblas import Matmul, auto_detect_nvidia_target
        from bitblas.cache import get_database_path, global_operator_cache
        BITBLAS_DATABASE_PATH = get_database_path()
        BITBLAS_TARGET = auto_detect_nvidia_target()

        if global_operator_cache.size() == 0:
            global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
                                                     BITBLAS_TARGET)

        bitblas_matmul = global_operator_cache.get(config)
        if bitblas_matmul is None:
            bitblas_matmul = Matmul(config,
                                    target=BITBLAS_TARGET,
                                    enable_tuning=False)
            if enable_tuning:
                bitblas_matmul.hardware_aware_finetune(topk=20)
                global_operator_cache.add(config, bitblas_matmul)
                global_operator_cache.save_into_database(
                    BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
                logger.info(
                    "BitBLAS Operator %s tuned and saved to database.", config)
            else:
                logger.info(
                    "BitBLAS Operator %s created without tuning.", config)
        else:
            logger.info("BitBLAS Operator %s retrieved from cache.", config)
        return bitblas_matmul

    def apply_gptq_bitblas_linear(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
    ) -> torch.Tensor:
        output_size_per_partition = self.config.partition_weight_shape[1]
        out_shape = x.shape[:-1] + (output_size_per_partition, )
        args = [x, layer.qweight, layer.scales]
        if self.bitblas_matmul.config.with_zeros:  # type: ignore[attr-defined]
            args.append(layer.qzeros)
        output = self.bitblas_matmul(*args)  # type: ignore[operator]
        return output.view(out_shape)

    def apply_weights(self, layer, x, bias=None):
        raise NotImplementedError(
            f"{self.__class__.__name__}.apply_weights is not implemented. "
            "Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead")
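
Typical lifecycle (illustrative). The sketch below shows how a caller is expected to drive this kernel, inferred from the methods above; c, quant_config, layer, x, and the feature sizes are hypothetical placeholders, and the exact call sites inside vLLM may differ:

import torch

# Hypothetical driver, inferred from the methods above (not vLLM API).
ok, reason = BitBLASLinearKernel.can_implement(c)   # 1. gate on support
if not ok:
    raise ValueError(f"BitBLAS kernel unavailable: {reason}")

kernel = BitBLASLinearKernel(c, "qweight", "scales",
                             w_zp_param_name="qzeros",
                             bitblas_quant_config=quant_config)
# 2. Build the matmul operator *before* repacking: repack_bitblas_from_gptq
#    asserts that self.bitblas_matmul is already configured.
kernel.configure_bitblas_matmul(in_features, out_features,
                                torch.float16, bias=False)
kernel.process_weights_after_loading(layer)         # 3. repack GPTQ weights
y = kernel.apply_gptq_bitblas_linear(layer, x)      # 4. forward pass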

BITBLAS_DTYPES class-attribute instance-attribute

BITBLAS_DTYPES: dict[dtype, str] = {
    float32: "float32",
    float16: "float16",
    bfloat16: "bfloat16",
    half: "float16",
    int8: "int8",
}

ENABLE_TUNING class-attribute instance-attribute

ENABLE_TUNING: bool = True

MATMUL_LAYOUT class-attribute instance-attribute

MATMUL_LAYOUT: str = 'nt'

OPT_FEATURES class-attribute instance-attribute

OPT_FEATURES: list[int] = BITBLAS_OPTIMIZE_FEATURES

bitblas_matmul class-attribute instance-attribute

bitblas_matmul: object = None

quant_config instance-attribute

quant_config = bitblas_quant_config

__init__

__init__(
    c: MPLinearLayerConfig,
    w_q_param_name: str,
    w_s_param_name: str,
    w_zp_param_name: Optional[str] = None,
    w_gidx_param_name: Optional[str] = None,
    bitblas_quant_config: Optional[
        QuantizationConfig
    ] = None,
)
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py
def __init__(
    self,
    c: MPLinearLayerConfig,
    w_q_param_name: str,
    w_s_param_name: str,
    w_zp_param_name: Optional[str] = None,
    w_gidx_param_name: Optional[str] = None,
    bitblas_quant_config: Optional[QuantizationConfig] = None,
):
    self.quant_config = bitblas_quant_config
    super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name,
                     w_gidx_param_name)

_configure_bitblas_matmul

_configure_bitblas_matmul(
    infeatures,
    outfeatures,
    params_dtype,
    enable_tuning,
    bias,
    layout,
    bits,
)
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py
def _configure_bitblas_matmul(
    self,
    infeatures,
    outfeatures,
    params_dtype,
    enable_tuning,
    bias,
    layout,
    bits,
):
    from bitblas import MatmulConfig
    bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
    quant_config = self.quant_config
    with_scaling = False
    with_zeros = False
    group_size = quant_config.group_size  # type: ignore[union-attr]
    zeros_mode = quant_config.zeros_mode  # type: ignore[union-attr]
    if quant_config.quant_method == "gptq":  # type: ignore[union-attr]
        with_scaling = True
        with_zeros = True
        W_dtype = f"uint{bits}"
        if quant_config.is_sym:  # type: ignore[union-attr]
            with_zeros = False
            W_dtype = f"int{bits}"
    else:
        raise ValueError(
            f"Unsupported quant_method {quant_config.quant_method}"  # type: ignore[union-attr]
        )

    matmul_config = MatmulConfig(
        M=self.OPT_FEATURES,
        N=outfeatures,
        K=infeatures,
        A_dtype=bitblas_dtype,
        W_dtype=W_dtype,
        out_dtype=bitblas_dtype,
        accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
        storage_dtype=quant_config.storage_dtype,  # type: ignore[union-attr]
        with_scaling=with_scaling,
        with_zeros=with_zeros,
        group_size=group_size,
        with_bias=bias,
        layout=layout,
        zeros_mode=zeros_mode,
    )
    self.bitblas_matmul = self._get_or_create_bitblas_operator(
        matmul_config, enable_tuning)
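
The gptq branch above derives the BitBLAS weight dtype and zero handling from two quantization flags. A minimal, standalone restatement of that mapping (the values are illustrative; in the method they come from self.quant_config):

# Illustrative values, not read from a real quant config.
bits, is_sym, A_dtype = 4, True, "float16"

with_scaling = True                    # GPTQ always carries group scales
with_zeros = not is_sym                # symmetric GPTQ needs no zero points
W_dtype = f"int{bits}" if is_sym else f"uint{bits}"
accum_dtype = "int32" if A_dtype == "int8" else A_dtype
print(W_dtype, with_zeros, accum_dtype)  # int4 False float16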

_get_or_create_bitblas_operator

_get_or_create_bitblas_operator(config, enable_tuning)
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py
def _get_or_create_bitblas_operator(self, config, enable_tuning):
    from bitblas import Matmul, auto_detect_nvidia_target
    from bitblas.cache import get_database_path, global_operator_cache
    BITBLAS_DATABASE_PATH = get_database_path()
    BITBLAS_TARGET = auto_detect_nvidia_target()

    if global_operator_cache.size() == 0:
        global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
                                                 BITBLAS_TARGET)

    bitblas_matmul = global_operator_cache.get(config)
    if bitblas_matmul is None:
        bitblas_matmul = Matmul(config,
                                target=BITBLAS_TARGET,
                                enable_tuning=False)
        if enable_tuning:
            bitblas_matmul.hardware_aware_finetune(topk=20)
            global_operator_cache.add(config, bitblas_matmul)
            global_operator_cache.save_into_database(
                BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
            logger.info(
                "BitBLAS Operator %s tuned and saved to database.", config)
        else:
            logger.info(
                "BitBLAS Operator %s created without tuning.", config)
    else:
        logger.info("BitBLAS Operator %s retrieved from cache.", config)
    return bitblas_matmul
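
The lookup above follows the usual BitBLAS pattern: warm the in-process cache from the on-disk database once, then tune and persist on a miss. A standalone sketch of the same pattern (requires bitblas and a visible NVIDIA GPU; the M/N/K sizes and dtypes are illustrative):

from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target
from bitblas.cache import get_database_path, global_operator_cache

target = auto_detect_nvidia_target()
db_path = get_database_path()

# Warm the in-process cache from the on-disk database once.
if global_operator_cache.size() == 0:
    global_operator_cache.load_from_database(db_path, target)

config = MatmulConfig(M=[1, 16, 256], N=4096, K=4096,
                      A_dtype="float16", W_dtype="int4",
                      out_dtype="float16", accum_dtype="float16",
                      layout="nt")
op = global_operator_cache.get(config)
if op is None:
    op = Matmul(config, target=target, enable_tuning=False)
    op.hardware_aware_finetune(topk=20)                # tune once...
    global_operator_cache.add(config, op)
    global_operator_cache.save_into_database(db_path, target)  # ...and persist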

apply_gptq_bitblas_linear

apply_gptq_bitblas_linear(
    layer: Module, x: Tensor
) -> Tensor
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py
def apply_gptq_bitblas_linear(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
) -> torch.Tensor:
    output_size_per_partition = self.config.partition_weight_shape[1]
    out_shape = x.shape[:-1] + (output_size_per_partition, )
    args = [x, layer.qweight, layer.scales]
    if self.bitblas_matmul.config.with_zeros:  # type: ignore[attr-defined]
        args.append(layer.qzeros)
    output = self.bitblas_matmul(*args)  # type: ignore[operator]
    return output.view(out_shape)
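
Note the output-shape handling: only the last dimension of x is contracted, so arbitrary leading batch dimensions are preserved. A CPU-only illustration of the same arithmetic:

import torch

x = torch.randn(2, 128, 4096, dtype=torch.float16)
out_features = 11008                        # partition_weight_shape[1]
out_shape = x.shape[:-1] + (out_features, )
print(out_shape)                            # torch.Size([2, 128, 11008])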

apply_weights

apply_weights(layer, x, bias=None)
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py
def apply_weights(self, layer, x, bias=None):
    raise NotImplementedError(
        f"{self.__class__.__name__}.apply_weights is not implemented. "
        "Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead")

can_implement classmethod

can_implement(
    c: MPLinearLayerConfig,
) -> tuple[bool, Optional[str]]
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py
@classmethod
def can_implement(cls,
                  c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:

    try:
        import bitblas
        if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
            return False, (
                f"bitblas>={MINIMUM_BITBLAS_VERSION} is required, "
                f"found {bitblas.__version__}. Please run "
                f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`")
    except ImportError:
        return False, ("bitblas is not installed. Please install it "
                       "by running `pip install "
                       f"bitblas>={MINIMUM_BITBLAS_VERSION}`")

    quant_types = query_bitblas_supported_quant_types(c.zero_points)
    if c.weight_type not in quant_types:
        return False, (f"Quant type ({c.weight_type}) not supported by"
                       f"  BitBLAS, supported types are: {quant_types}")

    if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES:
        return False, (f"Group size ({c.group_size}) not supported by "
                       "BitBLAS, supported group sizes are: "
                       f"{BITBLAS_SUPPORTED_GROUP_SIZES}")

    return check_bitblas_supports_shape(
        c.partition_weight_shape[1],  # out_features
        c.partition_weight_shape[0],  # in_features
        c.full_weight_shape[0],  # in_features
        c.group_size)
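
can_implement returns a (supported, reason) pair rather than raising, so callers can fall back to another kernel. A hedged helper showing the intended use (pick_reason is hypothetical, not vLLM API; c is an MPLinearLayerConfig as used throughout this module):

from typing import Optional

def pick_reason(c) -> Optional[str]:
    """Return None if BitBLAS can serve this layer, else why it cannot."""
    ok, reason = BitBLASLinearKernel.can_implement(c)
    if not ok:
        logger.info("BitBLAS unavailable, falling back: %s", reason)
    return None if ok else reason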

configure_bitblas_matmul

configure_bitblas_matmul(
    infeatures: int,
    outfeatures: int,
    params_dtype: dtype,
    bias: bool,
) -> None
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py
def configure_bitblas_matmul(
    self,
    infeatures: int,
    outfeatures: int,
    params_dtype: torch.dtype,
    bias: bool,
) -> None:
    enable_tuning = self.ENABLE_TUNING
    layout = self.MATMUL_LAYOUT
    bits = self.quant_config.weight_bits  # type: ignore[union-attr]
    self._configure_bitblas_matmul(
        infeatures,
        outfeatures,
        params_dtype,
        enable_tuning,
        bias,
        layout,
        bits,
    )

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py
@classmethod
def get_min_capability(cls) -> int:
    return 70
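
The returned value encodes a CUDA compute capability as major * 10 + minor, so 70 corresponds to SM 7.0 (Volta) or newer.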

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    device = getattr(layer, self.w_q_name).device
    c = self.config
    quant_config = self.quant_config

    # Default names since bitblas requires empty parameters for these,
    # TODO: remove this requirement from bitblas (allow optional tensors)
    if self.w_gidx_name is None:
        self.w_gidx_name = "g_idx"
    if self.w_zp_name is None:
        self.w_zp_name = "qzeros"

    if c.has_g_idx:
        g_idx, g_idx_sort_indices = bitblas_sort_g_idx(
            getattr(layer, self.w_gidx_name))
        self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
        layer.g_idx_sort_indices = g_idx_sort_indices
    else:
        setattr(layer, self.w_gidx_name, bitblas_make_empty_g_idx(device))
        layer.g_idx_sort_indices = bitblas_make_empty_g_idx(device)

    if c.zero_points:
        raise NotImplementedError("Zero points not supported by BitBLAS")
    else:
        setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device))

    # Repack weights
    bitblas_qweight, bitblas_scales, bitblas_qzeros = (
        self.repack_bitblas_from_gptq(
            layer.qweight,
            layer.scales,
            None if quant_config.is_sym else  # type: ignore[union-attr]
            layer.qzeros,  # type: ignore[union-attr]
        ))
    replace_parameter(layer, self.w_q_name, bitblas_qweight)
    replace_parameter(layer, self.w_s_name, bitblas_scales)
    if bitblas_qzeros is not None:
        replace_parameter(layer, self.w_zp_name, bitblas_qzeros)
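
Ordering caveat: this method ends by calling repack_bitblas_from_gptq, which asserts that self.bitblas_matmul is already set, so configure_bitblas_matmul must have been called first (see the lifecycle sketch after the class source above).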

repack_bitblas_from_gptq

repack_bitblas_from_gptq(
    b_q_weight: Tensor,
    scales: Tensor,
    qzeros: Optional[Tensor] = None,
)
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py
def repack_bitblas_from_gptq(
    self,
    b_q_weight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: Optional[torch.Tensor] = None,
):
    from bitblas.quantization.utils import general_compress
    assert self.bitblas_matmul is not None, "bitblas_matmul is None"

    quant_config = self.quant_config
    # qweight in the old GPTQ quant linear is stored as
    # (outfeatures, infeatures) and must be transposed.
    qweight = b_q_weight.T.contiguous().view(
        quant_config.torch_storage_dtype)  # type: ignore[union-attr]
    intweight = unpack_gptq_qweight(
        qweight,
        quant_config.weight_bits).contiguous()  # type: ignore[union-attr]
    if self.bitblas_matmul.weight_transform is not None:  # type: ignore[attr-defined]
        qweight = self.bitblas_matmul.weight_transform(  # type: ignore[attr-defined]
            intweight.cpu()).cuda()
    # scales in the old GPTQ quant linear are stored as
    # (infeatures // group_size, outfeatures) and must be transposed.
    scales = scales.T.contiguous()

    if qzeros is None:
        return qweight, scales, None

    # qzeros should be de-quantized to int zeros.
    weight_bits = quant_config.weight_bits  # type: ignore[union-attr]
    intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous()
    zeros: Optional[torch.Tensor] = None
    zeros_mode = self.bitblas_matmul.config.zeros_mode  # type: ignore[attr-defined]
    if zeros_mode == "original":
        zeros = intzeros.to(torch.float16).contiguous()
    elif zeros_mode == "rescale":
        assert zeros is not None, "zeros should not be None"
        zeros[:, :] = intzeros.to(torch.float16)[:, :] * scales[:, :]
    elif zeros_mode == "quantized":
        zeros = (
            torch.Tensor(
                general_compress(
                    intzeros.T.contiguous().cpu().numpy(),
                    weight_bits,
                ))
            .to(qweight.device)
            .to(quant_config.torch_storage_dtype)  # type: ignore[union-attr]
            .contiguous())
    else:
        raise ValueError("Unsupported zeros type: {}".format(zeros_mode))

    return qweight, scales, zeros
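
The three zeros_mode branches differ only in how the unpacked integer zeros are represented. A toy recap of the first two (shapes are illustrative; the "quantized" mode additionally re-packs to the low-bit storage dtype via bitblas's general_compress, as in the branch above):

import torch

intzeros = torch.randint(0, 16, (8, 128))          # unpacked zeros, transposed
scales = torch.rand(8, 128, dtype=torch.float16)   # matching group scales

z_original = intzeros.to(torch.float16)            # "original": raw offsets
z_rescale = intzeros.to(torch.float16) * scales    # "rescale": pre-scaled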