vllm.model_executor.layers.quantization.moe_wna16

MoeWNA16Config

Bases: QuantizationConfig

Config class for MOE WNA16 (W8A16/W4A16) quantization.

Source code in vllm/model_executor/layers/quantization/moe_wna16.py
class MoeWNA16Config(QuantizationConfig):
    """Config class for MOE WNA16 (W8A16/W4A16) quantization."""

    def __init__(self, linear_quant_method: str, weight_bits: int,
                 group_size: int, has_zp: bool, lm_head_quantized: bool,
                 modules_to_not_convert: Optional[list[str]],
                 full_config: dict[str, Any]) -> None:
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.has_zp = has_zp
        self.bit8_pack_factor = 8 // self.weight_bits
        self.lm_head_quantized = lm_head_quantized
        self.linear_quant_method = linear_quant_method
        self.full_config = full_config
        self.use_marlin = False
        # Avoid circular import
        from vllm.model_executor.layers.quantization.awq import AWQConfig
        from vllm.model_executor.layers.quantization.awq_marlin import (
            AWQMarlinConfig)
        from vllm.model_executor.layers.quantization.gptq_marlin import (
            GPTQMarlinConfig)
        if self.linear_quant_method == "gptq":
            self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(
                full_config)
        elif self.linear_quant_method == "awq":
            capability_tuple = current_platform.get_device_capability()
            device_capability = (-1 if capability_tuple is None else
                                 capability_tuple.to_int())
            awq_min_capability = AWQConfig.get_min_capability()
            if device_capability < awq_min_capability:
                raise ValueError(
                    "The quantization method moe_wna16 + awq is not supported "
                    "for the current GPU. "
                    f"Minimum capability: {awq_min_capability}. "
                    f"Current capability: {device_capability}.")
            self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible(
                full_config)
        else:
            raise ValueError("moe_wna16 only support gptq and awq.")

        if modules_to_not_convert is None:
            self.modules_to_not_convert = []
        else:
            self.modules_to_not_convert = modules_to_not_convert

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "moe_wna16"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config":
        linear_quant_method = cls.get_from_keys(config, ["quant_method"])
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        if linear_quant_method == "gptq":
            has_zp = not cls.get_from_keys(config, ["sym"])
            modules_to_not_convert = []
        elif linear_quant_method == "awq":
            has_zp = cls.get_from_keys(config, ["zero_point"])
            modules_to_not_convert = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None)
        else:
            raise ValueError("moe_wna16 only support gptq and awq.")

        return cls(linear_quant_method, weight_bits, group_size, has_zp,
                   lm_head_quantized, modules_to_not_convert, config)

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
        if can_convert and user_quant == "moe_wna16":
            return cls.get_name()
        return None

    @classmethod
    def is_moe_wna16_compatible(cls, quant_config: dict[str, Any]):
        # Extract data from quant config.
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        desc_act = quant_config.get("desc_act")

        capability_tuple = current_platform.get_device_capability()
        device_capability = (-1 if capability_tuple is None else
                             capability_tuple.to_int())
        # Avoid circular import
        from vllm.model_executor.layers.quantization.awq import AWQConfig
        awq_min_capability = AWQConfig.get_min_capability()

        gptq_compatible = quant_method == "gptq" and \
                not desc_act and num_bits in [4, 8]
        awq_compatible = quant_method == "awq" and num_bits == 4 and \
            device_capability >= awq_min_capability

        return gptq_compatible or awq_compatible

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
            return UnquantizedLinearMethod()
        elif isinstance(layer, LinearBase):
            # Avoid circular import
            from vllm.model_executor.layers.quantization.awq import AWQConfig
            from vllm.model_executor.layers.quantization.awq_marlin import (
                AWQMarlinConfig)
            from vllm.model_executor.layers.quantization.gptq import GPTQConfig
            from vllm.model_executor.layers.quantization.gptq_marlin import (
                GPTQMarlinConfig)
            if self.linear_quant_method == "gptq":
                if self.use_marlin:
                    return GPTQMarlinConfig.from_config(
                        self.full_config).get_quant_method(layer, prefix)
                else:
                    return GPTQConfig.from_config(
                        self.full_config).get_quant_method(layer, prefix)
            elif self.linear_quant_method == "awq":
                if self.use_marlin and check_marlin_supports_layer(
                        layer, self.group_size):
                    return AWQMarlinConfig.from_config(
                        self.full_config).get_quant_method(layer, prefix)
                else:
                    return AWQConfig.from_config(
                        self.full_config).get_quant_method(layer, prefix)
            else:
                raise ValueError("moe_wna16 only support gptq and awq.")
        elif isinstance(layer, FusedMoE):
            return MoeWNA16Method(self)
        return None
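
A minimal usage sketch (not part of the vLLM source; the config dict values are illustrative) showing how from_config builds a MoeWNA16Config from a GPTQ-style quantization config and which fields are derived from it:

# Hypothetical GPTQ-style quant config; the keys mirror those read by
# from_config() above, the concrete values are made up.
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config

gptq_cfg = {
    "quant_method": "gptq",   # selects the GPTQ branch of from_config
    "bits": 4,                # weight_bits -> W4A16
    "group_size": 128,
    "sym": True,              # has_zp = not sym = False
    "desc_act": False,
    "lm_head": False,
}

cfg = MoeWNA16Config.from_config(gptq_cfg)
print(cfg.weight_bits)        # 4
print(cfg.has_zp)             # False (symmetric GPTQ has no zero points)
print(cfg.bit8_pack_factor)   # 2 (two 4-bit weights per uint8)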

bit8_pack_factor instance-attribute

bit8_pack_factor = 8 // weight_bits

full_config instance-attribute

full_config = full_config

group_size instance-attribute

group_size = group_size

has_zp instance-attribute

has_zp = has_zp

linear_quant_method instance-attribute

linear_quant_method = linear_quant_method

lm_head_quantized instance-attribute

lm_head_quantized = lm_head_quantized

modules_to_not_convert instance-attribute

modules_to_not_convert = []

use_marlin instance-attribute

use_marlin = False

weight_bits instance-attribute

weight_bits = weight_bits

__init__

__init__(
    linear_quant_method: str,
    weight_bits: int,
    group_size: int,
    has_zp: bool,
    lm_head_quantized: bool,
    modules_to_not_convert: Optional[list[str]],
    full_config: dict[str, Any],
) -> None
Source code in vllm/model_executor/layers/quantization/moe_wna16.py
def __init__(self, linear_quant_method: str, weight_bits: int,
             group_size: int, has_zp: bool, lm_head_quantized: bool,
             modules_to_not_convert: Optional[list[str]],
             full_config: dict[str, Any]) -> None:
    super().__init__()
    self.weight_bits = weight_bits
    self.group_size = group_size
    self.has_zp = has_zp
    self.bit8_pack_factor = 8 // self.weight_bits
    self.lm_head_quantized = lm_head_quantized
    self.linear_quant_method = linear_quant_method
    self.full_config = full_config
    self.use_marlin = False
    # Avoid circular import
    from vllm.model_executor.layers.quantization.awq import AWQConfig
    from vllm.model_executor.layers.quantization.awq_marlin import (
        AWQMarlinConfig)
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinConfig)
    if self.linear_quant_method == "gptq":
        self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(
            full_config)
    elif self.linear_quant_method == "awq":
        capability_tuple = current_platform.get_device_capability()
        device_capability = (-1 if capability_tuple is None else
                             capability_tuple.to_int())
        awq_min_capability = AWQConfig.get_min_capability()
        if device_capability < awq_min_capability:
            raise ValueError(
                "The quantization method moe_wna16 + awq is not supported "
                "for the current GPU. "
                f"Minimum capability: {awq_min_capability}. "
                f"Current capability: {device_capability}.")
        self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible(
            full_config)
    else:
        raise ValueError("moe_wna16 only support gptq and awq.")

    if modules_to_not_convert is None:
        self.modules_to_not_convert = []
    else:
        self.modules_to_not_convert = modules_to_not_convert

from_config classmethod

from_config(config: dict[str, Any]) -> MoeWNA16Config
Source code in vllm/model_executor/layers/quantization/moe_wna16.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config":
    linear_quant_method = cls.get_from_keys(config, ["quant_method"])
    weight_bits = cls.get_from_keys(config, ["bits"])
    group_size = cls.get_from_keys(config, ["group_size"])
    lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                             default=False)
    if linear_quant_method == "gptq":
        has_zp = not cls.get_from_keys(config, ["sym"])
        modules_to_not_convert = []
    elif linear_quant_method == "awq":
        has_zp = cls.get_from_keys(config, ["zero_point"])
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None)
    else:
        raise ValueError("moe_wna16 only support gptq and awq.")

    return cls(linear_quant_method, weight_bits, group_size, has_zp,
               lm_head_quantized, modules_to_not_convert, config)

get_config_filenames classmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/moe_wna16.py
@classmethod
def get_config_filenames(cls) -> list[str]:
    return ["quantize_config.json"]

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/moe_wna16.py
@classmethod
def get_min_capability(cls) -> int:
    return 70

get_name classmethod

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/moe_wna16.py
@classmethod
def get_name(cls) -> QuantizationMethods:
    return "moe_wna16"

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/moe_wna16.py
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
        return UnquantizedLinearMethod()
    elif isinstance(layer, LinearBase):
        # Avoid circular import
        from vllm.model_executor.layers.quantization.awq import AWQConfig
        from vllm.model_executor.layers.quantization.awq_marlin import (
            AWQMarlinConfig)
        from vllm.model_executor.layers.quantization.gptq import GPTQConfig
        from vllm.model_executor.layers.quantization.gptq_marlin import (
            GPTQMarlinConfig)
        if self.linear_quant_method == "gptq":
            if self.use_marlin:
                return GPTQMarlinConfig.from_config(
                    self.full_config).get_quant_method(layer, prefix)
            else:
                return GPTQConfig.from_config(
                    self.full_config).get_quant_method(layer, prefix)
        elif self.linear_quant_method == "awq":
            if self.use_marlin and check_marlin_supports_layer(
                    layer, self.group_size):
                return AWQMarlinConfig.from_config(
                    self.full_config).get_quant_method(layer, prefix)
            else:
                return AWQConfig.from_config(
                    self.full_config).get_quant_method(layer, prefix)
        else:
            raise ValueError("moe_wna16 only support gptq and awq.")
    elif isinstance(layer, FusedMoE):
        return MoeWNA16Method(self)
    return None
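
Layers whose prefix matches modules_to_not_convert short-circuit to an unquantized linear method; LinearBase layers are delegated to the GPTQ/AWQ (Marlin) configs, and FusedMoE layers get MoeWNA16Method. A small sketch of the skip path (hypothetical prefixes and config values, assuming vLLM is installed):

import torch
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config

cfg = MoeWNA16Config(
    linear_quant_method="gptq", weight_bits=4, group_size=128,
    has_zp=False, lm_head_quantized=False,
    modules_to_not_convert=["visual"],
    full_config={"quant_method": "gptq", "bits": 4, "group_size": 128,
                 "sym": True, "desc_act": False})

dummy = torch.nn.Module()
m = cfg.get_quant_method(dummy, "visual.blocks.0.attn.qkv")
print(type(m).__name__)                           # UnquantizedLinearMethod (prefix is skipped)
print(cfg.get_quant_method(dummy, "model.norm"))  # None (neither LinearBase nor FusedMoE)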

get_supported_act_dtypes classmethod

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/moe_wna16.py
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    return [torch.bfloat16, torch.half]

is_moe_wna16_compatible classmethod

is_moe_wna16_compatible(quant_config: dict[str, Any])
Source code in vllm/model_executor/layers/quantization/moe_wna16.py
@classmethod
def is_moe_wna16_compatible(cls, quant_config: dict[str, Any]):
    # Extract data from quant config.
    quant_method = quant_config.get("quant_method", "").lower()
    num_bits = quant_config.get("bits")
    desc_act = quant_config.get("desc_act")

    capability_tuple = current_platform.get_device_capability()
    device_capability = (-1 if capability_tuple is None else
                         capability_tuple.to_int())
    # Avoid circular import
    from vllm.model_executor.layers.quantization.awq import AWQConfig
    awq_min_capability = AWQConfig.get_min_capability()

    gptq_compatible = quant_method == "gptq" and \
            not desc_act and num_bits in [4, 8]
    awq_compatible = quant_method == "awq" and num_bits == 4 and \
        device_capability >= awq_min_capability

    return gptq_compatible or awq_compatible
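
A quick sketch of the compatibility rules (illustrative dicts, assuming vLLM is installed): GPTQ qualifies only without act-order and with 4- or 8-bit weights, while the AWQ result additionally depends on the current GPU's compute capability.

from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config

# GPTQ: compatible only when desc_act is disabled and bits is 4 or 8.
print(MoeWNA16Config.is_moe_wna16_compatible(
    {"quant_method": "gptq", "bits": 4, "desc_act": False}))  # True
print(MoeWNA16Config.is_moe_wna16_compatible(
    {"quant_method": "gptq", "bits": 4, "desc_act": True}))   # False

# AWQ: requires 4-bit weights and a GPU meeting AWQ's minimum capability,
# so the result is device dependent.
print(MoeWNA16Config.is_moe_wna16_compatible(
    {"quant_method": "awq", "bits": 4}))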

override_quantization_method classmethod

override_quantization_method(
    hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]
Source code in vllm/model_executor/layers/quantization/moe_wna16.py
@classmethod
def override_quantization_method(
        cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
    can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
    if can_convert and user_quant == "moe_wna16":
        return cls.get_name()
    return None
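
Note that the override is opt-in: it only returns "moe_wna16" when the user explicitly requested that quantization method and the HF quant config is convertible. A minimal sketch (same illustrative dict as above):

from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config

hf_cfg = {"quant_method": "gptq", "bits": 4, "desc_act": False}
print(MoeWNA16Config.override_quantization_method(hf_cfg, "moe_wna16"))  # "moe_wna16"
print(MoeWNA16Config.override_quantization_method(hf_cfg, None))         # None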

MoeWNA16Method

Bases: FusedMoEMethodBase

Linear method for MOE WNA16 (W8A16/W4A16) quantization.

Parameters:

Name          Type            Description                                       Default
quant_config  MoeWNA16Config  The MOE WNA16 (W8A16/W4A16) quantization config.  required
Source code in vllm/model_executor/layers/quantization/moe_wna16.py
class MoeWNA16Method(FusedMoEMethodBase):
    """Linear method for MOE WNA16 (W8A16/W4A16) quantization.

    Args:
        quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
    """

    def __init__(self, quant_config: MoeWNA16Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        layer.quant_config = self.quant_config
        bit8_pack_factor = self.quant_config.bit8_pack_factor
        group_size = self.quant_config.group_size
        group_size_div_factor = 1

        # make intermediate_size and hidden_size divisible by group_size
        # we reduce the group size to ensure that
        # and we would repeat the loaded_weight later
        while intermediate_size_per_partition % group_size or \
                hidden_size % group_size:
            group_size = group_size // 2
            group_size_div_factor *= 2
            assert group_size >= 32
        layer.group_size = group_size
        layer.group_size_div_factor = group_size_div_factor

        strategy = FusedMoeWeightScaleSupported.GROUP.value
        extra_weight_attrs.update({
            "quant_method": strategy,
            "is_transposed": False
        })

        assert 'weight_loader' in extra_weight_attrs
        weight_loader = extra_weight_attrs['weight_loader']
        wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
            layer, weight_loader)
        extra_weight_attrs['weight_loader'] = wrapped_weight_loader

        # Fused gate_up_proj (column parallel)
        w13_qweight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size // bit8_pack_factor,
            dtype=torch.uint8),
                                         requires_grad=False)
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_qweight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition // bit8_pack_factor,
            dtype=torch.uint8),
                                        requires_grad=False)
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

        w13_scales = torch.nn.Parameter(torch.zeros(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size // group_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)

        w2_scales = torch.nn.Parameter(torch.zeros(
            num_experts,
            hidden_size,
            intermediate_size_per_partition // group_size,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)

        if self.quant_config.has_zp:
            w13_qzeros = torch.nn.Parameter(torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition // bit8_pack_factor,
                hidden_size // group_size,
                dtype=torch.uint8),
                                            requires_grad=False)
            layer.register_parameter("w13_qzeros", w13_qzeros)
            set_weight_attrs(w13_qzeros, extra_weight_attrs)

            w2_qzeros = torch.nn.Parameter(torch.zeros(
                num_experts,
                hidden_size // bit8_pack_factor,
                intermediate_size_per_partition // group_size,
                dtype=torch.uint8),
                                           requires_grad=False)
            layer.register_parameter("w2_qzeros", w2_qzeros)
            set_weight_attrs(w2_qzeros, extra_weight_attrs)

        if self.quant_config.linear_quant_method == "gptq":
            # some params are unused, but we need to init them in order to
            # load weights
            invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
            if not self.quant_config.has_zp:
                invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
            for key in invalid_param_keys:
                param = torch.nn.Parameter(torch.empty((0, ),
                                                       dtype=torch.int32),
                                           requires_grad=False)
                layer.register_parameter(key, param)
                set_weight_attrs(param, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `MoeWNA16Method` yet.")

        from vllm.model_executor.layers.fused_moe import fused_experts
        assert activation == "silu", "Only SiLU activation is supported."
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        weight_bits = self.quant_config.weight_bits
        has_zp = self.quant_config.has_zp

        return fused_experts(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            use_int4_w4a16=weight_bits == 4,
            use_int8_w8a16=weight_bits == 8,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=expert_map,
            w1_scale=layer.w13_scales,
            w2_scale=layer.w2_scales,
            w1_zp=layer.w13_qzeros if has_zp else None,
            w2_zp=layer.w2_qzeros if has_zp else None,
            block_shape=[0, layer.group_size])

    @staticmethod
    def get_weight_loader(layer, weight_loader):

        def convert_awq_tensor(tensor, tensor_type):
            # convert awq qweight/qzeros to a standard format (assume int4)
            # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
            # qzeros: (k // group_size, n // pack_factor_bit32) ->
            #         (n // pack_factor_bit8, k // group_size)
            # pack_factor_bit32 = 32 // weight_bits
            # pack_factor_bit8 = 8 // weight_bits

            # 0. suppose origin shape (a, b), dtype int32
            # 1. convert to uint8, shape (a, b) -> (a, 4 * b)
            size0 = tensor.size(0)
            tensor = tensor.view(torch.uint8)

            # 2. unpack to uint4 (only when weight_bits == 4)
            #    shape (a, 4 * b) -> (a, 4 * b, 2)
            shifter = torch.tensor([0, 4],
                                   dtype=torch.uint8,
                                   device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF

            # 3. change order, see
            # https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
            # shape -> (a, 4 * b * pack_factor_bit8)
            reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
            tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
            tensor = tensor.view(size0, -1)

            # 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
            tensor = tensor.T.contiguous()

            # 5. repack (only when weight_bits == 4)
            # qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
            # qzeros shape -> (4 * b, a)

            if tensor_type == "qweight":
                tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
            elif tensor_type == "qzeros":
                tensor = tensor[1::2, :] * 16 + tensor[::2, :]
            return tensor

        def convert_gptq_int4_qzeros(tensor):
            tensor = tensor.view(torch.uint8)
            shifter = torch.tensor([0, 4],
                                   dtype=torch.uint8,
                                   device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF
            tensor = tensor + 1
            tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
            return tensor

        def moe_wna16_weight_loader(param: torch.nn.Parameter,
                                    loaded_weight: torch.Tensor,
                                    weight_name: str, shard_id: str,
                                    expert_id: int):
            if "g_idx" in weight_name:
                return
            if not layer.quant_config.has_zp and "qzeros" in weight_name:
                return

            device = get_tp_group().device
            tp_rank = get_tensor_model_parallel_rank()
            loaded_weight = loaded_weight.to(device)
            shard_size = layer.intermediate_size_per_partition

            # convert gptq and awq weight to a standard format
            if layer.quant_config.linear_quant_method == "awq":
                assert layer.quant_config.weight_bits == 4
                if "weight" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight,
                                                       "qweight")
                elif "zeros" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
                else:
                    loaded_weight = loaded_weight.T
            elif layer.quant_config.linear_quant_method == "gptq":
                assert layer.quant_config.weight_bits in [4, 8]
                if "weight" in weight_name:
                    loaded_weight = loaded_weight.T.contiguous().view(
                        torch.uint8)
                elif "zeros" in weight_name:
                    # add 1 to gptq qzeros to align with awq
                    loaded_weight = loaded_weight.view(torch.uint8)
                    if layer.quant_config.weight_bits == 4:
                        loaded_weight = convert_gptq_int4_qzeros(
                            loaded_weight).T
                    else:
                        loaded_weight = loaded_weight.T + 1
                else:
                    loaded_weight = loaded_weight.T

            # repeat the qzeros/scales to fit new group size
            if layer.group_size_div_factor > 1 and \
                    "qzeros" in weight_name or "scales" in weight_name:
                loaded_weight = loaded_weight.repeat_interleave(
                    layer.group_size_div_factor, 1)

            if "w13_qzeros" in weight_name:
                tensor = loaded_weight.view(layer.tp_size, -1,
                                            loaded_weight.size(1))[tp_rank]
                if shard_id == "w1":
                    param.data[expert_id, :shard_size // 2] = tensor
                else:
                    param.data[expert_id, shard_size // 2:] = tensor
            elif "w2_qzeros" in weight_name:
                param.data[expert_id] = loaded_weight.view(
                    loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
            else:
                weight_loader(param, loaded_weight, weight_name, shard_id,
                              expert_id)

        return moe_wna16_weight_loader

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(quant_config: MoeWNA16Config)
Source code in vllm/model_executor/layers/quantization/moe_wna16.py
def __init__(self, quant_config: MoeWNA16Config):
    self.quant_config = quant_config

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/moe_wna16.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for `MoeWNA16Method` yet.")

    from vllm.model_executor.layers.fused_moe import fused_experts
    assert activation == "silu", "Only SiLU activation is supported."
    topk_weights, topk_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias)

    weight_bits = self.quant_config.weight_bits
    has_zp = self.quant_config.has_zp

    return fused_experts(
        x,
        layer.w13_qweight,
        layer.w2_qweight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=True,
        use_int4_w4a16=weight_bits == 4,
        use_int8_w8a16=weight_bits == 8,
        global_num_experts=global_num_experts,
        apply_router_weight_on_input=apply_router_weight_on_input,
        expert_map=expert_map,
        w1_scale=layer.w13_scales,
        w2_scale=layer.w2_scales,
        w1_zp=layer.w13_qzeros if has_zp else None,
        w2_zp=layer.w2_qzeros if has_zp else None,
        block_shape=[0, layer.group_size])

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/moe_wna16.py
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):

    layer.quant_config = self.quant_config
    bit8_pack_factor = self.quant_config.bit8_pack_factor
    group_size = self.quant_config.group_size
    group_size_div_factor = 1

    # make intermediate_size and hidden_size divisible by group_size
    # we reduce the group size to ensure that
    # and we would repeat the loaded_weight later
    while intermediate_size_per_partition % group_size or \
            hidden_size % group_size:
        group_size = group_size // 2
        group_size_div_factor *= 2
        assert group_size >= 32
    layer.group_size = group_size
    layer.group_size_div_factor = group_size_div_factor

    strategy = FusedMoeWeightScaleSupported.GROUP.value
    extra_weight_attrs.update({
        "quant_method": strategy,
        "is_transposed": False
    })

    assert 'weight_loader' in extra_weight_attrs
    weight_loader = extra_weight_attrs['weight_loader']
    wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
        layer, weight_loader)
    extra_weight_attrs['weight_loader'] = wrapped_weight_loader

    # Fused gate_up_proj (column parallel)
    w13_qweight = torch.nn.Parameter(torch.empty(
        num_experts,
        2 * intermediate_size_per_partition,
        hidden_size // bit8_pack_factor,
        dtype=torch.uint8),
                                     requires_grad=False)
    layer.register_parameter("w13_qweight", w13_qweight)
    set_weight_attrs(w13_qweight, extra_weight_attrs)

    # down_proj (row parallel)
    w2_qweight = torch.nn.Parameter(torch.empty(
        num_experts,
        hidden_size,
        intermediate_size_per_partition // bit8_pack_factor,
        dtype=torch.uint8),
                                    requires_grad=False)
    layer.register_parameter("w2_qweight", w2_qweight)
    set_weight_attrs(w2_qweight, extra_weight_attrs)

    w13_scales = torch.nn.Parameter(torch.zeros(
        num_experts,
        2 * intermediate_size_per_partition,
        hidden_size // group_size,
        dtype=params_dtype),
                                    requires_grad=False)
    layer.register_parameter("w13_scales", w13_scales)
    set_weight_attrs(w13_scales, extra_weight_attrs)

    w2_scales = torch.nn.Parameter(torch.zeros(
        num_experts,
        hidden_size,
        intermediate_size_per_partition // group_size,
        dtype=params_dtype),
                                   requires_grad=False)
    layer.register_parameter("w2_scales", w2_scales)
    set_weight_attrs(w2_scales, extra_weight_attrs)

    if self.quant_config.has_zp:
        w13_qzeros = torch.nn.Parameter(torch.zeros(
            num_experts,
            2 * intermediate_size_per_partition // bit8_pack_factor,
            hidden_size // group_size,
            dtype=torch.uint8),
                                        requires_grad=False)
        layer.register_parameter("w13_qzeros", w13_qzeros)
        set_weight_attrs(w13_qzeros, extra_weight_attrs)

        w2_qzeros = torch.nn.Parameter(torch.zeros(
            num_experts,
            hidden_size // bit8_pack_factor,
            intermediate_size_per_partition // group_size,
            dtype=torch.uint8),
                                       requires_grad=False)
        layer.register_parameter("w2_qzeros", w2_qzeros)
        set_weight_attrs(w2_qzeros, extra_weight_attrs)

    if self.quant_config.linear_quant_method == "gptq":
        # some params are unused, but we need to init them in order to
        # load weights
        invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
        if not self.quant_config.has_zp:
            invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
        for key in invalid_param_keys:
            param = torch.nn.Parameter(torch.empty((0, ),
                                                   dtype=torch.int32),
                                       requires_grad=False)
            layer.register_parameter(key, param)
            set_weight_attrs(param, extra_weight_attrs)
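
To illustrate the group-size reduction loop above, here is a standalone sketch (hypothetical sizes, not from the vLLM source) of how group_size and group_size_div_factor end up when a dimension is not divisible by the configured group size:

def reduce_group_size(hidden_size: int, intermediate_size: int, group_size: int):
    # Mirror of the loop in create_weights(): halve the group size until both
    # dimensions are divisible by it, and track the factor so scales/qzeros
    # can later be repeat_interleave'd to match.
    group_size_div_factor = 1
    while intermediate_size % group_size or hidden_size % group_size:
        group_size //= 2
        group_size_div_factor *= 2
        assert group_size >= 32
    return group_size, group_size_div_factor

# e.g. a hypothetical model with hidden_size=2112 and a configured group size
# of 128: 2112 % 128 != 0 but 2112 % 64 == 0, so one halving suffices.
print(reduce_group_size(2112, 5632, 128))  # (64, 2)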

get_weight_loader staticmethod

get_weight_loader(layer, weight_loader)
Source code in vllm/model_executor/layers/quantization/moe_wna16.py
@staticmethod
def get_weight_loader(layer, weight_loader):

    def convert_awq_tensor(tensor, tensor_type):
        # convert awq qweight/qzeros to a standard format (assume int4)
        # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
        # qzeros: (k // group_size, n // pack_factor_bit32) ->
        #         (n // pack_factor_bit8, k // group_size)
        # pack_factor_bit32 = 32 // weight_bits
        # pack_factor_bit8 = 8 // weight_bits

        # 0. suppose origin shape (a, b), dtype int32
        # 1. convert to uint8, shape (a, b) -> (a, 4 * b)
        size0 = tensor.size(0)
        tensor = tensor.view(torch.uint8)

        # 2. unpack to uint4 (only when weight_bits == 4)
        #    shape (a, 4 * b) -> (a, 4 * b, 2)
        shifter = torch.tensor([0, 4],
                               dtype=torch.uint8,
                               device=tensor.device)
        tensor = (tensor[:, :, None] >> shifter) & 0xF

        # 3. change order, see
        # https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
        # shape -> (a, 4 * b * pack_factor_bit8)
        reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
        tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
        tensor = tensor.view(size0, -1)

        # 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
        tensor = tensor.T.contiguous()

        # 5. repack (only when weight_bits == 4)
        # qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
        # qzeros shape -> (4 * b, a)

        if tensor_type == "qweight":
            tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
        elif tensor_type == "qzeros":
            tensor = tensor[1::2, :] * 16 + tensor[::2, :]
        return tensor

    def convert_gptq_int4_qzeros(tensor):
        tensor = tensor.view(torch.uint8)
        shifter = torch.tensor([0, 4],
                               dtype=torch.uint8,
                               device=tensor.device)
        tensor = (tensor[:, :, None] >> shifter) & 0xF
        tensor = tensor + 1
        tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
        return tensor

    def moe_wna16_weight_loader(param: torch.nn.Parameter,
                                loaded_weight: torch.Tensor,
                                weight_name: str, shard_id: str,
                                expert_id: int):
        if "g_idx" in weight_name:
            return
        if not layer.quant_config.has_zp and "qzeros" in weight_name:
            return

        device = get_tp_group().device
        tp_rank = get_tensor_model_parallel_rank()
        loaded_weight = loaded_weight.to(device)
        shard_size = layer.intermediate_size_per_partition

        # convert gptq and awq weight to a standard format
        if layer.quant_config.linear_quant_method == "awq":
            assert layer.quant_config.weight_bits == 4
            if "weight" in weight_name:
                loaded_weight = convert_awq_tensor(loaded_weight,
                                                   "qweight")
            elif "zeros" in weight_name:
                loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
            else:
                loaded_weight = loaded_weight.T
        elif layer.quant_config.linear_quant_method == "gptq":
            assert layer.quant_config.weight_bits in [4, 8]
            if "weight" in weight_name:
                loaded_weight = loaded_weight.T.contiguous().view(
                    torch.uint8)
            elif "zeros" in weight_name:
                # add 1 to gptq qzeros to align with awq
                loaded_weight = loaded_weight.view(torch.uint8)
                if layer.quant_config.weight_bits == 4:
                    loaded_weight = convert_gptq_int4_qzeros(
                        loaded_weight).T
                else:
                    loaded_weight = loaded_weight.T + 1
            else:
                loaded_weight = loaded_weight.T

        # repeat the qzeros/scales to fit new group size
        if layer.group_size_div_factor > 1 and \
                "qzeros" in weight_name or "scales" in weight_name:
            loaded_weight = loaded_weight.repeat_interleave(
                layer.group_size_div_factor, 1)

        if "w13_qzeros" in weight_name:
            tensor = loaded_weight.view(layer.tp_size, -1,
                                        loaded_weight.size(1))[tp_rank]
            if shard_id == "w1":
                param.data[expert_id, :shard_size // 2] = tensor
            else:
                param.data[expert_id, shard_size // 2:] = tensor
        elif "w2_qzeros" in weight_name:
            param.data[expert_id] = loaded_weight.view(
                loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
        else:
            weight_loader(param, loaded_weight, weight_name, shard_id,
                          expert_id)

    return moe_wna16_weight_loader
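
To follow the shape bookkeeping in convert_awq_tensor, this self-contained sketch (random data, illustrative sizes) repeats the same steps on a fake 4-bit AWQ qweight and prints the resulting shape:

import torch

# Hypothetical sizes: k = 64 input features, n = 32 output features, 4-bit AWQ.
# AWQ stores qweight as int32 with 8 nibbles per element: shape (k, n // 8).
k, n = 64, 32
qweight = torch.randint(0, 2**31 - 1, (k, n // 8), dtype=torch.int32)

t = qweight.view(torch.uint8)                    # (k, n // 2): 4 bytes per int32
shifter = torch.tensor([0, 4], dtype=torch.uint8)
t = (t[:, :, None] >> shifter) & 0xF             # (k, n // 2, 2): one nibble each
t = t.view(-1, 8)[:, [0, 4, 1, 5, 2, 6, 3, 7]]   # undo AWQ's interleaved pack order
t = t.view(k, -1)                                # (k, n)
t = t.T.contiguous()                             # (n, k)
t = t[:, 1::2] * 16 + t[:, ::2]                  # repack two nibbles per uint8

print(qweight.shape, "->", t.shape)              # (64, 4) -> (32, 32), i.e. (n, k // 2)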

is_layer_skipped_quant

is_layer_skipped_quant(
    prefix: str, modules_to_not_convert: list[str]
)
Source code in vllm/model_executor/layers/quantization/moe_wna16.py
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: list[str]):
    return any(module_name in prefix for module_name in modules_to_not_convert)
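
The check is a plain substring match over the layer prefix. A minimal usage sketch (made-up module names):

from vllm.model_executor.layers.quantization.moe_wna16 import is_layer_skipped_quant

modules_to_not_convert = ["visual", "mlp.gate"]
print(is_layer_skipped_quant("model.layers.0.mlp.gate", modules_to_not_convert))           # True
print(is_layer_skipped_quant("model.layers.0.self_attn.qkv_proj", modules_to_not_convert)) # False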