vllm.model_executor.layers.quantization.gptq_marlin

logger module-attribute

logger = init_logger(__name__)

GPTQMarlinConfig

Bases: QuantizationConfig

Config class for GPTQ Marlin

Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    # (num_bits, is_sym) -> quant_type
    TYPE_MAP = {
        (4, True): scalar_types.uint4b8,
        (8, True): scalar_types.uint8b128,
    }

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool, lm_head_quantized: bool,
                 dynamic: dict[str, dict[str, Union[int, bool]]],
                 full_config: dict[str, Any]) -> None:
        super().__init__()
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        # GPTQModel uses the `dynamic` config property to allow per-module
        # quantization config so each module can be individually optimized.
        # The format is dict[str, dict], where each key is a regex string
        # that performs either positive ("+:" prefixed) or negative
        # ("-:" prefixed) matching of a module. If no prefix is used, the
        # match defaults to positive and overrides the base quant config.
        # Each value is a dict mapping config field names to override values.
        # A negative match skips quantization init for that module entirely,
        # i.e. it runs non-quantized inference. More details and quantization
        # examples can be found at: https://github.com/ModelCloud/GPTQModel
        # Example:
        #  # the last 1/2 of the layers (10-21) use 8-bit vs 4-bit for 0-9
        #  # the last 1/4 of the layers (16-21) use 8-bit and group_size 64
        # dynamic = {
        #  #`.*\.` matches the layers_node prefix
        #  # positive match layer 10-15
        #  r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
        #  # positive match layer 16-21
        #  r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
        #  r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
        # }
        self.dynamic = dynamic

        self.weight_bits = weight_bits
        self.is_sym = is_sym

        self.pack_factor = 32 // weight_bits  # packed into int32
        self.group_size = group_size
        self.desc_act = desc_act
        self.lm_head_quantized = lm_head_quantized
        self.full_config = full_config

        if (weight_bits, is_sym) not in self.TYPE_MAP:
            raise ValueError("Unsupported quantization config: "
                             f"bits={weight_bits}, sym={is_sym}")

        self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}, "
                f"lm_head_quantized={self.lm_head_quantized}), "
                f"dynamic={self.dynamic}")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig":
        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
        dynamic = {} if dynamic is None else dynamic

        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, desc_act, is_sym,
                   lm_head_quantized, dynamic, config)

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

        is_valid_user_quant = (user_quant is None or user_quant == "marlin"
                               or user_quant == "gptq_marlin")

        if can_convert and is_valid_user_quant:
            msg = ("The model is convertible to {} during runtime."
                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info("Detected that the model can run with gptq_marlin"
                        ", however you specified quantization=gptq explicitly,"
                        " so forcing gptq. Use quantization=gptq_marlin for"
                        " faster inference")
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, FusedMoE):
            from vllm.model_executor.layers.quantization.moe_wna16 import (
                MoeWNA16Config)
            if not check_moe_marlin_supports_layer(layer, self.group_size):
                logger.warning_once(
                    f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
                    "Falling back to Moe WNA16 kernels.")
                return MoeWNA16Config.from_config(
                    self.full_config).get_quant_method(layer, prefix)
            return get_moe_quant_method(self, layer, prefix,
                                        GPTQMarlinMoEMethod)
        return get_linear_quant_method(self, layer, prefix,
                                       GPTQMarlinLinearMethod)

    @classmethod
    def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")
        sym = quant_config.get("sym")
        desc_act = quant_config.get("desc_act")

        if not current_platform.is_cuda():
            return False

        if quant_method != "gptq":
            return False

        # Marlin conversion is only valid if required properties are found
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        if (num_bits, sym) not in cls.TYPE_MAP:
            return False

        return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
                                      group_size=group_size)
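
A hedged usage sketch, assuming a working vLLM installation: the dict keys below mirror those read by from_config ("bits", "group_size", "desc_act", "sym", "lm_head", plus the optional "dynamic" per-module override map described in the comments above); the concrete values and regexes are illustrative only.

from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig

quantize_config = {
    "bits": 4,              # must be 4 or 8 (see TYPE_MAP)
    "group_size": 128,
    "desc_act": False,
    "sym": True,
    "lm_head": False,
    # Optional per-module overrides: "+:" = positive match, "-:" = skip module
    "dynamic": {
        r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64},
        r"-:.*\.moe\..*": {},  # negative match: run these modules unquantized
    },
}

cfg = GPTQMarlinConfig.from_config(quantize_config)
print(cfg)  # GPTQMarlinConfig(quant_type=..., group_size=128, ...)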

TYPE_MAP class-attribute instance-attribute

TYPE_MAP = {(4, True): uint4b8, (8, True): uint8b128}

desc_act instance-attribute

desc_act = desc_act

dynamic instance-attribute

dynamic = dynamic

full_config instance-attribute

full_config = full_config

group_size instance-attribute

group_size = group_size

is_sym instance-attribute

is_sym = is_sym

lm_head_quantized instance-attribute

lm_head_quantized = lm_head_quantized

pack_factor instance-attribute

pack_factor = 32 // weight_bits

quant_type instance-attribute

quant_type = TYPE_MAP[weight_bits, is_sym]

weight_bits instance-attribute

weight_bits = weight_bits

__init__

__init__(
    weight_bits: int,
    group_size: int,
    desc_act: bool,
    is_sym: bool,
    lm_head_quantized: bool,
    dynamic: dict[str, dict[str, Union[int, bool]]],
    full_config: dict[str, Any],
) -> None
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
             is_sym: bool, lm_head_quantized: bool,
             dynamic: dict[str, dict[str, Union[int, bool]]],
             full_config: dict[str, Any]) -> None:
    super().__init__()
    if desc_act and group_size == -1:
        # In this case, act_order == True is the same as act_order == False
        # (since we have only one group per output channel)
        desc_act = False

    # GPTQModel uses the `dynamic` config property to allow per-module
    # quantization config so each module can be individually optimized.
    # The format is dict[str, dict], where each key is a regex string
    # that performs either positive ("+:" prefixed) or negative
    # ("-:" prefixed) matching of a module. If no prefix is used, the
    # match defaults to positive and overrides the base quant config.
    # Each value is a dict mapping config field names to override values.
    # A negative match skips quantization init for that module entirely,
    # i.e. it runs non-quantized inference. More details and quantization
    # examples can be found at: https://github.com/ModelCloud/GPTQModel
    # Example:
    #  # the last 1/2 of the layers (10-21) use 8-bit vs 4-bit for 0-9
    #  # the last 1/4 of the layers (16-21) use 8-bit and group_size 64
    # dynamic = {
    #  #`.*\.` matches the layers_node prefix
    #  # positive match layer 10-15
    #  r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
    #  # positive match layer 16-21
    #  r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
    #  r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
    # }
    self.dynamic = dynamic

    self.weight_bits = weight_bits
    self.is_sym = is_sym

    self.pack_factor = 32 // weight_bits  # packed into int32
    self.group_size = group_size
    self.desc_act = desc_act
    self.lm_head_quantized = lm_head_quantized
    self.full_config = full_config

    if (weight_bits, is_sym) not in self.TYPE_MAP:
        raise ValueError("Unsupported quantization config: "
                         f"bits={weight_bits}, sym={is_sym}")

    self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
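
A small pure-Python sketch of the arithmetic and validation performed above: weights are packed into int32, so pack_factor = 32 // weight_bits, and only the (bits, sym) pairs present in TYPE_MAP are accepted; anything else raises ValueError.

# Mirrors TYPE_MAP: only 4-bit and 8-bit symmetric quantization is supported.
supported = {(4, True), (8, True)}

for weight_bits, is_sym in [(4, True), (8, True), (3, True), (4, False)]:
    if (weight_bits, is_sym) not in supported:
        print(f"bits={weight_bits}, sym={is_sym}: unsupported -> ValueError")
        continue
    pack_factor = 32 // weight_bits  # weights packed per int32
    print(f"bits={weight_bits}, sym={is_sym}: pack_factor={pack_factor}")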

__repr__

__repr__() -> str
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
def __repr__(self) -> str:
    return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
            f"group_size={self.group_size}, "
            f"desc_act={self.desc_act}, "
            f"lm_head_quantized={self.lm_head_quantized}), "
            f"dynamic={self.dynamic}")

from_config classmethod

from_config(config: dict[str, Any]) -> GPTQMarlinConfig
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig":
    dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
    dynamic = {} if dynamic is None else dynamic

    weight_bits = cls.get_from_keys(config, ["bits"])
    group_size = cls.get_from_keys(config, ["group_size"])
    desc_act = cls.get_from_keys(config, ["desc_act"])
    is_sym = cls.get_from_keys(config, ["sym"])
    lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                             default=False)
    return cls(weight_bits, group_size, desc_act, is_sym,
               lm_head_quantized, dynamic, config)

get_config_filenames classmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
@classmethod
def get_config_filenames(cls) -> list[str]:
    return ["quantize_config.json"]

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
@classmethod
def get_min_capability(cls) -> int:
    return 80

get_name classmethod

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
@classmethod
def get_name(cls) -> QuantizationMethods:
    return "gptq_marlin"

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    if isinstance(layer, FusedMoE):
        from vllm.model_executor.layers.quantization.moe_wna16 import (
            MoeWNA16Config)
        if not check_moe_marlin_supports_layer(layer, self.group_size):
            logger.warning_once(
                f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
                "Falling back to Moe WNA16 kernels.")
            return MoeWNA16Config.from_config(
                self.full_config).get_quant_method(layer, prefix)
        return get_moe_quant_method(self, layer, prefix,
                                    GPTQMarlinMoEMethod)
    return get_linear_quant_method(self, layer, prefix,
                                   GPTQMarlinLinearMethod)

get_supported_act_dtypes classmethod

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    return [torch.half, torch.bfloat16]

is_gptq_marlin_compatible classmethod

is_gptq_marlin_compatible(quant_config: dict[str, Any])
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
@classmethod
def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
    quant_method = quant_config.get("quant_method", "").lower()
    num_bits = quant_config.get("bits")
    group_size = quant_config.get("group_size")
    sym = quant_config.get("sym")
    desc_act = quant_config.get("desc_act")

    if not current_platform.is_cuda():
        return False

    if quant_method != "gptq":
        return False

    # Marlin conversion is only valid if required properties are found
    if (num_bits is None or group_size is None or sym is None
            or desc_act is None):
        return False

    if (num_bits, sym) not in cls.TYPE_MAP:
        return False

    return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
                                  group_size=group_size)
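
A hedged sketch, assuming vLLM on a CUDA-capable platform: the check reads the "quant_method", "bits", "group_size", "sym", and "desc_act" fields of a Hugging Face quantization config; the dicts below are illustrative.

from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig

# 4-bit symmetric GPTQ: convertible to Marlin on a supported CUDA GPU.
print(GPTQMarlinConfig.is_gptq_marlin_compatible({
    "quant_method": "gptq", "bits": 4, "group_size": 128,
    "sym": True, "desc_act": False,
}))

# Asymmetric quantization (sym=False) is not in TYPE_MAP, so this is False.
print(GPTQMarlinConfig.is_gptq_marlin_compatible({
    "quant_method": "gptq", "bits": 4, "group_size": 128,
    "sym": False, "desc_act": False,
}))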

override_quantization_method classmethod

override_quantization_method(
    hf_quant_cfg, user_quant
) -> Optional[QuantizationMethods]
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
@classmethod
def override_quantization_method(
        cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
    can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

    is_valid_user_quant = (user_quant is None or user_quant == "marlin"
                           or user_quant == "gptq_marlin")

    if can_convert and is_valid_user_quant:
        msg = ("The model is convertible to {} during runtime."
               " Using {} kernel.".format(cls.get_name(), cls.get_name()))
        logger.info(msg)
        return cls.get_name()

    if can_convert and user_quant == "gptq":
        logger.info("Detected that the model can run with gptq_marlin"
                    ", however you specified quantization=gptq explicitly,"
                    " so forcing gptq. Use quantization=gptq_marlin for"
                    " faster inference")
    return None
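
A hedged sketch of the override decision, assuming vLLM on a CUDA platform: with a compatible Hugging Face GPTQ config, user_quant=None (or "marlin"/"gptq_marlin") upgrades the method to "gptq_marlin", while an explicit user_quant="gptq" is respected and None is returned.

from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig

hf_quant_cfg = {"quant_method": "gptq", "bits": 4, "group_size": 128,
                "sym": True, "desc_act": False}

print(GPTQMarlinConfig.override_quantization_method(hf_quant_cfg, None))
# -> "gptq_marlin" if convertible on this platform
print(GPTQMarlinConfig.override_quantization_method(hf_quant_cfg, "gptq"))
# -> None (the explicitly requested gptq kernel is kept)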

GPTQMarlinLinearMethod

Bases: LinearMethodBase

Linear method for GPTQ Marlin.

Parameters:

Name          Type              Description                           Default
quant_config  GPTQMarlinConfig  The GPTQ Marlin quantization config.  required
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    _kernel_backends_being_used: set[str] = set()

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

        # Verify supported on platform.
        verify_marlin_supported(quant_type=self.quant_config.quant_type,
                                group_size=self.quant_config.group_size)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        output_size_per_partition = sum(output_partition_sizes)
        is_row_parallel = input_size != input_size_per_partition
        weight_loader = extra_weight_attrs.get("weight_loader")

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(input_size_per_partition,
                                    output_size_per_partition),
            weight_type=self.quant_config.quant_type,
            act_type=params_dtype,
            group_size=self.quant_config.group_size,
            zero_points=False,
            has_g_idx=self.quant_config.desc_act
        )

        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for GPTQMarlinLinearMethod",
                        kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Determine sharding
        if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
                                             self.quant_config.group_size,
                                             is_row_parallel):
            # By setting scale_dim == None, weight_loader will
            # repeat the scales on each GPU in TP>1 case.
            scales_and_zp_input_dim = None
            scales_and_zp_size = input_size // group_size
        else:
            # By setting scale_dim == 0, weight_loader will
            # shard the scales in TP>1 case.
            scales_and_zp_input_dim = 0
            scales_and_zp_size = input_size_per_partition // group_size

        # Quantized weights
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        # Activation order
        g_idx = RowvLLMParameter(data=torch.empty(
            input_size_per_partition,
            dtype=torch.int32,
        ),
                                 input_dim=0,
                                 weight_loader=weight_loader)

        qzeros_args = {
            "data":
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader":
            weight_loader
        }
        weight_scale_args = {
            "data":
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }

        if scales_and_zp_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1,
                                                **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)

        else:
            scales = GroupQuantScaleParameter(output_dim=1,
                                              input_dim=0,
                                              **weight_scale_args)
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)

        self.kernel = kernel_type(mp_linear_kernel_config,
                                  w_q_param_name="qweight",
                                  w_s_param_name="scales",
                                  w_zp_param_name="qzeros",
                                  w_gidx_param_name="g_idx")

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.kernel.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.kernel.apply_weights(layer, x, bias)
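
A pure-Python sketch of the parameter shapes that create_weights registers, for a hypothetical 4-bit, group_size=128 layer whose scales are sharded (the scales_and_zp_input_dim == 0 branch); all sizes are illustrative.

input_size_per_partition = 4096
output_size_per_partition = 11008
pack_factor = 32 // 4            # int32 packing, as in GPTQMarlinConfig
group_size = 128
scales_and_zp_size = input_size_per_partition // group_size

qweight_shape = (input_size_per_partition // pack_factor,       # packed input dim
                 output_size_per_partition)                     # -> (512, 11008), int32
g_idx_shape = (input_size_per_partition,)                       # -> (4096,), int32
scales_shape = (scales_and_zp_size, output_size_per_partition)  # -> (32, 11008), act dtype
qzeros_shape = (scales_and_zp_size,
                output_size_per_partition // pack_factor)       # -> (32, 1376), int32

print(qweight_shape, g_idx_shape, scales_shape, qzeros_shape)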

_kernel_backends_being_used class-attribute instance-attribute

_kernel_backends_being_used: set[str] = set()

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(quant_config: GPTQMarlinConfig) -> None
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
    self.quant_config = quant_config

    # Verify supported on platform.
    verify_marlin_supported(quant_type=self.quant_config.quant_type,
                            group_size=self.quant_config.group_size)

apply

apply(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return self.kernel.apply_weights(layer, x, bias)

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
) -> None
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
) -> None:
    output_size_per_partition = sum(output_partition_sizes)
    is_row_parallel = input_size != input_size_per_partition
    weight_loader = extra_weight_attrs.get("weight_loader")

    mp_linear_kernel_config = MPLinearLayerConfig(
        full_weight_shape=(input_size, output_size),
        partition_weight_shape=(input_size_per_partition,
                                output_size_per_partition),
        weight_type=self.quant_config.quant_type,
        act_type=params_dtype,
        group_size=self.quant_config.group_size,
        zero_points=False,
        has_g_idx=self.quant_config.desc_act
    )

    kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

    if kernel_type.__name__ not in self._kernel_backends_being_used:
        logger.info("Using %s for GPTQMarlinLinearMethod",
                    kernel_type.__name__)
        self._kernel_backends_being_used.add(kernel_type.__name__)

    # Normalize group_size
    if self.quant_config.group_size != -1:
        group_size = self.quant_config.group_size
    else:
        group_size = input_size

    # Determine sharding
    if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
                                         self.quant_config.group_size,
                                         is_row_parallel):
        # By setting scale_dim == None, weight_loader will
        # repeat the scales on each GPU in TP>1 case.
        scales_and_zp_input_dim = None
        scales_and_zp_size = input_size // group_size
    else:
        # By setting scale_dim == 0, weight_loader will
        # shard the scales in TP>1 case.
        scales_and_zp_input_dim = 0
        scales_and_zp_size = input_size_per_partition // group_size

    # Quantized weights
    qweight = PackedvLLMParameter(
        data=torch.empty(
            input_size_per_partition // self.quant_config.pack_factor,
            output_size_per_partition,
            dtype=torch.int32,
        ),
        input_dim=0,
        output_dim=1,
        packed_dim=0,
        packed_factor=self.quant_config.pack_factor,
        weight_loader=weight_loader)

    # Activation order
    g_idx = RowvLLMParameter(data=torch.empty(
        input_size_per_partition,
        dtype=torch.int32,
    ),
                             input_dim=0,
                             weight_loader=weight_loader)

    qzeros_args = {
        "data":
        torch.empty(
            scales_and_zp_size,
            output_size_per_partition // self.quant_config.pack_factor,
            dtype=torch.int32,
        ),
        "weight_loader":
        weight_loader
    }
    weight_scale_args = {
        "data":
        torch.empty(
            scales_and_zp_size,
            output_size_per_partition,
            dtype=params_dtype,
        ),
        "weight_loader":
        weight_loader
    }

    if scales_and_zp_input_dim is None:
        scales = ChannelQuantScaleParameter(output_dim=1,
                                            **weight_scale_args)
        qzeros = PackedColumnParameter(
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            **qzeros_args)

    else:
        scales = GroupQuantScaleParameter(output_dim=1,
                                          input_dim=0,
                                          **weight_scale_args)
        qzeros = PackedvLLMParameter(
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            **qzeros_args)

    layer.register_parameter("qweight", qweight)
    layer.register_parameter("g_idx", g_idx)
    layer.register_parameter("scales", scales)
    layer.register_parameter("qzeros", qzeros)

    self.kernel = kernel_type(mp_linear_kernel_config,
                              w_q_param_name="qweight",
                              w_s_param_name="scales",
                              w_zp_param_name="qzeros",
                              w_gidx_param_name="g_idx")

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    self.kernel.process_weights_after_loading(layer)

GPTQMarlinMoEMethod

Bases: FusedMoEMethodBase

MoE Marlin method with quantization.

Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
    """MoE Marlin method with quantization."""

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config
        if self.quant_config.quant_type.size_bits == 4:
            self.quant_type = scalar_types.uint4b8
        elif self.quant_config.quant_type.size_bits == 8:
            self.quant_type = scalar_types.uint8b128
        else:
            raise ValueError(
                "GPTQMarlinMoEMethod only supports int4 and int8 now.")

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        intermediate_size_full = extra_weight_attrs.pop(
            "intermediate_size_full")

        self.is_k_full = (not self.quant_config.desc_act) or (
            intermediate_size_per_partition == intermediate_size_full)

        if self.quant_config.group_size != -1:
            scales_size13 = hidden_size // self.quant_config.group_size
            w2_scales_size = (intermediate_size_full
                              if self.quant_config.desc_act else
                              intermediate_size_per_partition)
            scales_size2 = (w2_scales_size // self.quant_config.group_size)
            strategy = FusedMoeWeightScaleSupported.GROUP.value
        else:
            scales_size13 = 1
            scales_size2 = 1
            strategy = FusedMoeWeightScaleSupported.CHANNEL.value

        extra_weight_attrs.update({
            "quant_method": strategy,
            "is_transposed": True
        })
        # Fused gate_up_proj (column parallel)
        w13_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size // self.quant_config.pack_factor,
                2 * intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)
        # down_proj (row parallel)
        w2_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition //
                self.quant_config.pack_factor,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)
        # up_proj scales
        w13_scales = torch.nn.Parameter(
            torch.empty(num_experts,
                        scales_size13,
                        2 * intermediate_size_per_partition,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)
        # down_proj scales
        w2_scales = torch.nn.Parameter(
            torch.empty(num_experts,
                        scales_size2,
                        hidden_size,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)
        # don't shard the w2 scales when running act order
        set_weight_attrs(w2_scales,
                         {"load_full_w2": self.quant_config.desc_act})
        # up_proj qzeros
        w13_qzeros = torch.nn.Parameter(
            torch.empty(num_experts,
                        scales_size13,
                        2 * intermediate_size_per_partition //
                        self.quant_config.pack_factor,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w13_qzeros", w13_qzeros)
        set_weight_attrs(w13_qzeros, extra_weight_attrs)
        # down_proj qzeros
        w2_qzeros = torch.nn.Parameter(
            torch.empty(num_experts,
                        scales_size2,
                        hidden_size // self.quant_config.pack_factor,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_qzeros", w2_qzeros)
        set_weight_attrs(w2_qzeros, extra_weight_attrs)
        # don't shard the w2 qzeros when running act order
        set_weight_attrs(w2_qzeros,
                         {"load_full_w2": self.quant_config.desc_act})
        w13_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx", w13_g_idx)
        set_weight_attrs(w13_g_idx, extra_weight_attrs)
        w2_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx", w2_g_idx)
        set_weight_attrs(w2_g_idx, extra_weight_attrs)
        w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx_sort_indices",
                                 w13_g_idx_sort_indices)
        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
        w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx_sort_indices",
                                 w2_g_idx_sort_indices)
        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)

        device = layer.w13_qweight.device
        layer.workspace = marlin_make_workspace_new(device, 4)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

        # Process act_order
        if self.quant_config.desc_act:
            # Get sorting based on g_idx
            num_experts = layer.w13_g_idx.shape[0]
            w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
            w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
            w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
            w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
            for e in range(num_experts):
                w13_g_idx_sort_indices[e] = torch.argsort(
                    layer.w13_g_idx[e]).to(torch.int32)
                w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
                    torch.int32)
                w13_sorted_g_idx[e] = layer.w13_g_idx[e][
                    w13_g_idx_sort_indices[e]]
                w2_sorted_g_idx[e] = layer.w2_g_idx[e][
                    w2_g_idx_sort_indices[e]]
            replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
            replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
            replace_parameter(layer, "w13_g_idx_sort_indices",
                              w13_g_idx_sort_indices)
            replace_parameter(layer, "w2_g_idx_sort_indices",
                              w2_g_idx_sort_indices)
        else:
            # Reset g_idx related tensors
            num_experts = layer.w13_g_idx.shape[0]
            device = layer.w13_g_idx.device
            layer.w13_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
            layer.w2_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
        # Repack weights
        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w13_qweight.shape[2],
            self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w2_qweight.shape[2],
            self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
        # Repack scales
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w13_scales", marlin_w13_scales)
        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.w2_scales.shape[1] *
            (self.quant_config.group_size if self.quant_config.group_size != -1
             else self.quant_config.pack_factor),
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w2_scales", marlin_w2_scales)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `GPTQMarlinMoEMethod` yet.")

        assert activation == "silu", "Only SiLU activation is supported."

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            layer.w13_scales,
            layer.w2_scales,
            router_logits,
            topk_weights,
            topk_ids,
            quant_type_id=self.quant_type.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            g_idx1=layer.w13_g_idx,
            g_idx2=layer.w2_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            workspace=layer.workspace,
            is_k_full=self.is_k_full)
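
A pure-Python sketch of the expert parameter shapes registered by create_weights above, for a hypothetical 4-bit, group_size=128 MoE layer with desc_act=False (so the w2 scales use the per-partition intermediate size); all sizes are illustrative, and the qzeros follow the same pattern with the last dim divided by pack_factor.

num_experts = 8
hidden_size = 4096
intermediate_size_per_partition = 1408
pack_factor = 32 // 4
group_size = 128

w13_qweight_shape = (num_experts,
                     hidden_size // pack_factor,
                     2 * intermediate_size_per_partition)   # (8, 512, 2816), int32
w2_qweight_shape = (num_experts,
                    intermediate_size_per_partition // pack_factor,
                    hidden_size)                            # (8, 176, 4096), int32
w13_scales_shape = (num_experts,
                    hidden_size // group_size,
                    2 * intermediate_size_per_partition)    # (8, 32, 2816)
w2_scales_shape = (num_experts,
                   intermediate_size_per_partition // group_size,
                   hidden_size)                             # (8, 11, 4096)

print(w13_qweight_shape, w2_qweight_shape, w13_scales_shape, w2_scales_shape)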

quant_config instance-attribute

quant_config = quant_config

quant_type instance-attribute

quant_type = uint4b8

__init__

__init__(quant_config: GPTQMarlinConfig) -> None
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
    self.quant_config = quant_config
    if self.quant_config.quant_type.size_bits == 4:
        self.quant_type = scalar_types.uint4b8
    elif self.quant_config.quant_type.size_bits == 8:
        self.quant_type = scalar_types.uint8b128
    else:
        raise ValueError(
            "GPTQMarlinMoEMethod only supports int4 and int8 now.")

apply

apply(
    layer: Module,
    x: Tensor,
    router_logits: Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[Tensor] = None,
    logical_to_physical_map: Optional[Tensor] = None,
    logical_replica_count: Optional[Tensor] = None,
) -> Tensor
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for `GPTQMarlinMoEMethod` yet.")

    assert activation == "silu", "Only SiLU activation is supported."

    topk_weights, topk_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias)

    return torch.ops.vllm.fused_marlin_moe(
        x,
        layer.w13_qweight,
        layer.w2_qweight,
        layer.w13_scales,
        layer.w2_scales,
        router_logits,
        topk_weights,
        topk_ids,
        quant_type_id=self.quant_type.id,
        apply_router_weight_on_input=apply_router_weight_on_input,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        g_idx1=layer.w13_g_idx,
        g_idx2=layer.w2_g_idx,
        sort_indices1=layer.w13_g_idx_sort_indices,
        sort_indices2=layer.w2_g_idx_sort_indices,
        workspace=layer.workspace,
        is_k_full=self.is_k_full)
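
A hedged, torch-free sketch of the routing step that FusedMoE.select_experts performs conceptually for the default softmax scoring: softmax the router logits, keep the top_k experts per token, and optionally renormalize their weights; the fused Marlin MoE kernel then consumes the resulting (topk_weights, topk_ids). This is an illustration of the idea, not the actual vLLM implementation.

import math

router_logits = [0.1, 2.0, -0.5, 1.2]     # one token, 4 experts (illustrative)
top_k, renormalize = 2, True

probs = [math.exp(logit) for logit in router_logits]
total = sum(probs)
probs = [p / total for p in probs]         # softmax scoring

topk_ids = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:top_k]
topk_weights = [probs[i] for i in topk_ids]
if renormalize:
    s = sum(topk_weights)
    topk_weights = [w / s for w in topk_weights]

print(topk_ids, topk_weights)              # e.g. [1, 3] and renormalized weights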

create_weights

create_weights(
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
def create_weights(
    self,
    layer: torch.nn.Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size_per_partition: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    intermediate_size_full = extra_weight_attrs.pop(
        "intermediate_size_full")

    self.is_k_full = (not self.quant_config.desc_act) or (
        intermediate_size_per_partition == intermediate_size_full)

    if self.quant_config.group_size != -1:
        scales_size13 = hidden_size // self.quant_config.group_size
        w2_scales_size = (intermediate_size_full
                          if self.quant_config.desc_act else
                          intermediate_size_per_partition)
        scales_size2 = (w2_scales_size // self.quant_config.group_size)
        strategy = FusedMoeWeightScaleSupported.GROUP.value
    else:
        scales_size13 = 1
        scales_size2 = 1
        strategy = FusedMoeWeightScaleSupported.CHANNEL.value

    extra_weight_attrs.update({
        "quant_method": strategy,
        "is_transposed": True
    })
    # Fused gate_up_proj (column parallel)
    w13_qweight = torch.nn.Parameter(
        torch.empty(
            num_experts,
            hidden_size // self.quant_config.pack_factor,
            2 * intermediate_size_per_partition,
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w13_qweight", w13_qweight)
    set_weight_attrs(w13_qweight, extra_weight_attrs)
    # down_proj (row parallel)
    w2_qweight = torch.nn.Parameter(
        torch.empty(
            num_experts,
            intermediate_size_per_partition //
            self.quant_config.pack_factor,
            hidden_size,
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w2_qweight", w2_qweight)
    set_weight_attrs(w2_qweight, extra_weight_attrs)
    # up_proj scales
    w13_scales = torch.nn.Parameter(
        torch.empty(num_experts,
                    scales_size13,
                    2 * intermediate_size_per_partition,
                    dtype=params_dtype),
        requires_grad=False,
    )
    layer.register_parameter("w13_scales", w13_scales)
    set_weight_attrs(w13_scales, extra_weight_attrs)
    # down_proj scales
    w2_scales = torch.nn.Parameter(
        torch.empty(num_experts,
                    scales_size2,
                    hidden_size,
                    dtype=params_dtype),
        requires_grad=False,
    )
    layer.register_parameter("w2_scales", w2_scales)
    set_weight_attrs(w2_scales, extra_weight_attrs)
    # don't shard the w2 scales when running act order
    set_weight_attrs(w2_scales,
                     {"load_full_w2": self.quant_config.desc_act})
    # up_proj qzeros
    w13_qzeros = torch.nn.Parameter(
        torch.empty(num_experts,
                    scales_size13,
                    2 * intermediate_size_per_partition //
                    self.quant_config.pack_factor,
                    dtype=params_dtype),
        requires_grad=False,
    )
    layer.register_parameter("w13_qzeros", w13_qzeros)
    set_weight_attrs(w13_qzeros, extra_weight_attrs)
    # down_proj qzeros
    w2_qzeros = torch.nn.Parameter(
        torch.empty(num_experts,
                    scales_size2,
                    hidden_size // self.quant_config.pack_factor,
                    dtype=params_dtype),
        requires_grad=False,
    )
    layer.register_parameter("w2_qzeros", w2_qzeros)
    set_weight_attrs(w2_qzeros, extra_weight_attrs)
    # don't shard the w2 qzeros when running act order
    set_weight_attrs(w2_qzeros,
                     {"load_full_w2": self.quant_config.desc_act})
    w13_g_idx = torch.nn.Parameter(
        torch.empty(
            num_experts,
            hidden_size,
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w13_g_idx", w13_g_idx)
    set_weight_attrs(w13_g_idx, extra_weight_attrs)
    w2_g_idx = torch.nn.Parameter(
        torch.empty(
            num_experts,
            intermediate_size_per_partition,
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w2_g_idx", w2_g_idx)
    set_weight_attrs(w2_g_idx, extra_weight_attrs)
    w13_g_idx_sort_indices = torch.nn.Parameter(
        torch.empty(
            num_experts,
            hidden_size,
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w13_g_idx_sort_indices",
                             w13_g_idx_sort_indices)
    set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
    w2_g_idx_sort_indices = torch.nn.Parameter(
        torch.empty(
            num_experts,
            intermediate_size_per_partition,
            dtype=torch.int32,
        ),
        requires_grad=False,
    )
    layer.register_parameter("w2_g_idx_sort_indices",
                             w2_g_idx_sort_indices)
    set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)

    device = layer.w13_qweight.device
    layer.workspace = marlin_make_workspace_new(device, 4)

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

    # Process act_order
    if self.quant_config.desc_act:
        # Get sorting based on g_idx
        num_experts = layer.w13_g_idx.shape[0]
        w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
        w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
        w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
        w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
        for e in range(num_experts):
            w13_g_idx_sort_indices[e] = torch.argsort(
                layer.w13_g_idx[e]).to(torch.int32)
            w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
                torch.int32)
            w13_sorted_g_idx[e] = layer.w13_g_idx[e][
                w13_g_idx_sort_indices[e]]
            w2_sorted_g_idx[e] = layer.w2_g_idx[e][
                w2_g_idx_sort_indices[e]]
        replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
        replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
        replace_parameter(layer, "w13_g_idx_sort_indices",
                          w13_g_idx_sort_indices)
        replace_parameter(layer, "w2_g_idx_sort_indices",
                          w2_g_idx_sort_indices)
    else:
        # Reset g_idx related tensors
        num_experts = layer.w13_g_idx.shape[0]
        device = layer.w13_g_idx.device
        layer.w13_g_idx = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32,
                        device=device),
            requires_grad=False,
        )
        layer.w2_g_idx = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32,
                        device=device),
            requires_grad=False,
        )
        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32,
                        device=device),
            requires_grad=False,
        )
        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32,
                        device=device),
            requires_grad=False,
        )
    # Repack weights
    marlin_w13_qweight = ops.gptq_marlin_moe_repack(
        layer.w13_qweight,
        layer.w13_g_idx_sort_indices,
        layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
        layer.w13_qweight.shape[2],
        self.quant_config.quant_type.size_bits,
    )
    replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
    marlin_w2_qweight = ops.gptq_marlin_moe_repack(
        layer.w2_qweight,
        layer.w2_g_idx_sort_indices,
        layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
        layer.w2_qweight.shape[2],
        self.quant_config.quant_type.size_bits,
    )
    replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
    # Repack scales
    marlin_w13_scales = marlin_moe_permute_scales(
        s=layer.w13_scales,
        size_k=layer.intermediate_size_per_partition,
        size_n=layer.w13_scales.shape[2],
        group_size=self.quant_config.group_size,
    )
    replace_parameter(layer, "w13_scales", marlin_w13_scales)
    marlin_w2_scales = marlin_moe_permute_scales(
        s=layer.w2_scales,
        size_k=layer.w2_scales.shape[1] *
        (self.quant_config.group_size if self.quant_config.group_size != -1
         else self.quant_config.pack_factor),
        size_n=layer.w2_scales.shape[2],
        group_size=self.quant_config.group_size,
    )
    replace_parameter(layer, "w2_scales", marlin_w2_scales)

get_moe_quant_method

get_moe_quant_method(
    config: QuantizationConfig,
    layer: Module,
    prefix: str,
    moe_method_cls: type,
)
Source code in vllm/model_executor/layers/quantization/gptq_marlin.py
def get_moe_quant_method(
    config: QuantizationConfig,
    layer: torch.nn.Module,
    prefix: str,
    moe_method_cls: type,
):
    cloned_config = deepcopy(config)

    if isinstance(layer, FusedMoE):
        # False = skip module, None = no override, else = Positive match
        if get_dynamic_override(  # noqa: E712
                cloned_config,  # noqa: E712
                layer_name=prefix) == False:  # noqa: E712
            return UnquantizedFusedMoEMethod(layer.moe_config)

        if prefix:
            # Dynamic per module/layer rules may override base config
            override_config(cloned_config, prefix=prefix)

        return moe_method_cls(cloned_config)
    return None
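
A conceptual, pure-Python illustration of the dynamic-rule semantics described in GPTQMarlinConfig (not the actual get_dynamic_override implementation): a "-:" regex that matches a module name means "skip quantization" (the False case above), a matching "+:" or unprefixed regex yields its override dict, and no match means no override (None). The layer names and rules are illustrative.

import re

dynamic = {
    r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64},
    r"-:.*\.moe\..*": {},
}

def lookup(layer_name):
    for pattern, overrides in dynamic.items():
        negative = pattern.startswith("-:")
        regex = pattern[2:] if pattern[:2] in ("+:", "-:") else pattern
        if re.fullmatch(regex, layer_name):
            return False if negative else overrides
    return None

print(lookup("model.layers.18.mlp.down_proj"))    # {'bits': 8, 'group_size': 64}
print(lookup("model.layers.3.moe.experts"))       # False -> unquantized fallback
print(lookup("model.layers.3.self_attn.q_proj"))  # None -> base config applies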