Skip to content

vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin

MarlinLinearKernel

Bases: MPLinearKernel

Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
class MarlinLinearKernel(MPLinearKernel):
    """Mixed-precision linear kernel backed by the Marlin GPU kernels.

    After checkpoint loading, ``process_weights_after_loading`` repacks the
    quantized weight, scales, and (optional) zero points / ``g_idx`` into
    the layout Marlin expects. ``apply_weights`` then dispatches the matmul
    to ``apply_gptq_marlin_linear``.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability this kernel requires.

        NOTE(review): presumably encoded as ``major * 10 + minor`` (so 80
        would be SM 8.0) -- confirm against the ``MPLinearKernel`` contract.
        """
        return 80

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
        """Check whether Marlin can run a layer described by ``c``.

        Args:
            c: Layer configuration. Reads ``weight_type``, ``group_size``,
                ``zero_points``, ``partition_weight_shape`` and
                ``full_weight_shape``.

        Returns:
            ``(False, reason)`` if the quant type or group size is not
            supported. Otherwise, the result of
            ``check_marlin_supports_shape`` for this layer's shapes.
        """
        # The set of supported quant types depends on whether zero points
        # are used.
        quant_types = query_marlin_supported_quant_types(c.zero_points)
        if c.weight_type not in quant_types:
            return False, (f"Quant type ({c.weight_type}) not supported by "
                           f"Marlin, supported types are: {quant_types}")

        if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
            return False, (f"Group size ({c.group_size}) not supported by "
                           "Marlin, supported group sizes are: "
                           f"{MARLIN_SUPPORTED_GROUP_SIZES}")

        return check_marlin_supports_shape(
            c.partition_weight_shape[1],  # out_features (this partition)
            c.partition_weight_shape[0],  # in_features (this partition)
            c.full_weight_shape[0],  # in_features (full, unpartitioned)
            c.group_size)

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded parameters on ``layer`` into Marlin's format.

        Steps:
            - Compute ``self.is_k_full`` and allocate ``self.workspace``.
            - Sort ``g_idx`` (or install empty placeholders).
            - Unpack and permute zero points (or install an empty
              placeholder).
            - Repack the quantized weight and permute the scales.

        Args:
            layer: Module holding the parameters named by ``self.w_q_name``,
                ``self.w_s_name``, and optionally ``self.w_zp_name`` and
                ``self.w_gidx_name``. The method mutates it in place.
        """
        device = getattr(layer, self.w_q_name).device
        c = self.config

        # The input (K) dimension is sharded when this partition's K differs
        # from the full weight's K.
        row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0])
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace_new(device)

        # Default names since marlin requires empty parameters for these,
        # TODO: remove this requirement from marlin (allow optional tensors)
        if self.w_gidx_name is None:
            self.w_gidx_name = "g_idx"
        if self.w_zp_name is None:
            self.w_zp_name = "w_zp"

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            # `layer.g_idx_sort_indices` is read at call time; it is set by
            # the g_idx branch below before this transform runs.
            x.data = ops.gptq_marlin_repack(x.data.contiguous(),
                                            perm=layer.g_idx_sort_indices,
                                            size_k=c.partition_weight_shape[0],
                                            size_n=c.partition_weight_shape[1],
                                            num_bits=c.weight_type.size_bits)
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(x.data.contiguous(),
                                           size_k=c.partition_weight_shape[0],
                                           size_n=c.partition_weight_shape[1],
                                           group_size=c.group_size)
            return x

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name))
            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        if c.zero_points:
            # Number of quantization groups along K. A group_size of -1
            # yields a single group.
            grouped_k = ((c.partition_weight_shape[0] // c.group_size)
                         if c.group_size != -1 else 1)

            def transform_w_zp(x):
                return marlin_zero_points(
                    unpack_cols(x.t(), c.weight_type.size_bits,
                                grouped_k,
                                c.partition_weight_shape[1]),
                    size_k=grouped_k,
                    size_n=c.partition_weight_shape[1],
                    num_bits=c.weight_type.size_bits)

            self._transform_param(layer, self.w_zp_name, transform_w_zp)
        else:
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the Marlin-packed quantized linear layer to ``x``.

        Args:
            layer: Module previously processed by
                ``process_weights_after_loading``.
            x: Input activations.
            bias: Optional bias forwarded to the kernel.

        Returns:
            The output of ``apply_gptq_marlin_linear``.
        """
        c = self.config
        w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)

        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
        #  None for marlin
        return apply_gptq_marlin_linear(
            input=x,
            weight=w_q,
            weight_scale=w_s,
            weight_zp=w_zp,  # type: ignore
            g_idx=w_gidx,  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            bias=bias)

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
def apply_weights(self,
                  layer: torch.nn.Module,
                  x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the Marlin mixed-precision linear kernel on ``x``.

    Args:
        layer: Module whose parameters were already converted by
            ``process_weights_after_loading``.
        x: Input activations.
        bias: Optional bias forwarded unchanged to the kernel.

    Returns:
        Whatever ``apply_gptq_marlin_linear`` produces.
    """
    cfg = self.config
    packed_weight, scales, zeros, group_idx = self._get_weight_params(layer)

    # Zero points and g_idx are guaranteed non-None here because
    # `process_weights_after_loading` installs empty placeholders for them.
    return apply_gptq_marlin_linear(
        input=x,
        weight=packed_weight,
        weight_scale=scales,
        weight_zp=zeros,  # type: ignore
        g_idx=group_idx,  # type: ignore
        g_idx_sort_indices=layer.g_idx_sort_indices,
        workspace=self.workspace,
        wtype=cfg.weight_type,
        input_size_per_partition=cfg.partition_weight_shape[0],
        output_size_per_partition=cfg.partition_weight_shape[1],
        is_k_full=self.is_k_full,
        bias=bias,
    )

can_implement classmethod

can_implement(
    c: MPLinearLayerConfig,
) -> tuple[bool, Optional[str]]
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
@classmethod
def can_implement(cls,
                  c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
    """Check whether Marlin can run a layer described by ``c``.

    Args:
        c: Layer configuration. Reads ``weight_type``, ``group_size``,
            ``zero_points``, ``partition_weight_shape`` and
            ``full_weight_shape``.

    Returns:
        ``(False, reason)`` if the quant type or group size is not
        supported. Otherwise, the result of ``check_marlin_supports_shape``
        for this layer's shapes.
    """
    # The set of supported quant types depends on whether zero points are
    # used.
    quant_types = query_marlin_supported_quant_types(c.zero_points)
    if c.weight_type not in quant_types:
        return False, (f"Quant type ({c.weight_type}) not supported by "
                       f"Marlin, supported types are: {quant_types}")

    if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
        return False, (f"Group size ({c.group_size}) not supported by "
                       "Marlin, supported group sizes are: "
                       f"{MARLIN_SUPPORTED_GROUP_SIZES}")

    return check_marlin_supports_shape(
        c.partition_weight_shape[1],  # out_features (this partition)
        c.partition_weight_shape[0],  # in_features (this partition)
        c.full_weight_shape[0],  # in_features (full, unpartitioned)
        c.group_size)

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
@classmethod
def get_min_capability(cls) -> int:
    """Return the minimum device capability required by Marlin.

    NOTE(review): presumably encoded as ``major * 10 + minor`` (so 80 would
    be SM 8.0) -- confirm against the ``MPLinearKernel`` contract.
    """
    required_capability = 80
    return required_capability

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Rewrite the loaded parameters on ``layer`` into Marlin's layout.

    Sets ``self.is_k_full`` and ``self.workspace``. It then sorts ``g_idx``
    (or installs empty placeholders), unpacks and permutes zero points (or
    installs an empty placeholder), and finally repacks the quantized weight
    and permutes the scales.

    Args:
        layer: Module holding the parameters named by ``self.w_q_name``,
            ``self.w_s_name``, and optionally ``self.w_zp_name`` and
            ``self.w_gidx_name``. The method mutates it in place.
    """
    device = getattr(layer, self.w_q_name).device
    cfg = self.config
    size_k = cfg.partition_weight_shape[0]
    size_n = cfg.partition_weight_shape[1]
    num_bits = cfg.weight_type.size_bits

    # K is sharded whenever this partition's K differs from the full K.
    is_row_parallel = size_k != cfg.full_weight_shape[0]
    self.is_k_full = marlin_is_k_full(cfg.has_g_idx, is_row_parallel)

    # Allocate marlin workspace.
    self.workspace = marlin_make_workspace_new(device)

    # Marlin needs (possibly empty) g_idx / zero-point parameters, so fall
    # back to default attribute names when none were configured.
    # TODO: remove this requirement from marlin (allow optional tensors)
    self.w_gidx_name = "g_idx" if self.w_gidx_name is None else self.w_gidx_name
    self.w_zp_name = "w_zp" if self.w_zp_name is None else self.w_zp_name

    def repack_qweight(param):
        assert isinstance(param, BasevLLMParameter)
        permute_param_layout_(param, input_dim=0, output_dim=1, packed_dim=0)
        # Sort indices are looked up lazily: they are assigned below,
        # before this transform is applied.
        param.data = ops.gptq_marlin_repack(param.data.contiguous(),
                                            perm=layer.g_idx_sort_indices,
                                            size_k=size_k,
                                            size_n=size_n,
                                            num_bits=num_bits)
        return param

    def permute_scales(param):
        assert isinstance(param, BasevLLMParameter)
        permute_param_layout_(param, input_dim=0, output_dim=1)
        param.data = marlin_permute_scales(param.data.contiguous(),
                                           size_k=size_k,
                                           size_n=size_n,
                                           group_size=cfg.group_size)
        return param

    if not cfg.has_g_idx:
        setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
    else:
        sorted_g_idx, sort_indices = marlin_sort_g_idx(
            getattr(layer, self.w_gidx_name))
        self._transform_param(layer, self.w_gidx_name,
                              lambda _unused: sorted_g_idx)
        layer.g_idx_sort_indices = sort_indices

    if not cfg.zero_points:
        setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
    else:
        # Number of quantization groups along K (one group when -1).
        num_groups = size_k // cfg.group_size if cfg.group_size != -1 else 1

        def unpack_and_permute_zp(param):
            return marlin_zero_points(
                unpack_cols(param.t(), num_bits, num_groups, size_n),
                size_k=num_groups,
                size_n=size_n,
                num_bits=num_bits)

        self._transform_param(layer, self.w_zp_name, unpack_and_permute_zp)

    self._transform_param(layer, self.w_q_name, repack_qweight)
    self._transform_param(layer, self.w_s_name, permute_scales)