
vllm.model_executor.layers.fused_moe.experts.nvfp4_emulation_moe

NVFP4 quantization emulation for MoE.

This file implements NVFP4 emulation for NVFP4 MoE layers, for use when the hardware does not natively support NVFP4 MoE.

Weights are dequantized on the fly during each forward pass, the computation falls back to TritonExperts in BF16, and a fake NVFP4 quantize-dequantize is applied to the a13 and a2 activations.
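The emulation idea can be illustrated with a simplified quantize-dequantize round trip. The following is a minimal sketch, not the kernel used here: it ignores the FP8 encoding of the block scales and the per-tensor global scale that the real path applies, and fake_nvfp4_qdq is a hypothetical helper name.

import torch

# Representable magnitudes of the FP4 (e2m1) format used by NVFP4.
E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fake_nvfp4_qdq(x: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    """Quantize-dequantize x to emulate NVFP4 precision while keeping its
    original shape and dtype. Assumes the last dim divides by block_size."""
    orig_shape, orig_dtype = x.shape, x.dtype
    blocks = x.float().reshape(-1, block_size)
    # One scale per block of 16 values so that the block amax maps to 6.0,
    # the largest e2m1 magnitude. Real NVFP4 also rounds this scale to FP8.
    scale = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 6.0
    scaled = blocks / scale
    # Snap each scaled value to the nearest representable e2m1 magnitude.
    grid = E2M1_GRID.to(scaled.device)
    idx = (scaled.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    snapped = grid[idx] * scaled.sign()
    return (snapped * scale).reshape(orig_shape).to(orig_dtype)

a = torch.randn(4, 64, dtype=torch.bfloat16)
a_qdq = fake_nvfp4_qdq(a)  # same shape/dtype, with NVFP4-level precision loss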

Nvfp4QuantizationEmulationTritonExperts

Bases: TritonExperts

Extension of TritonExperts to support emulated NVFP4 MoE experts.

It may be used for NVFP4 models when the device does not have native support for this dtype.

Source code in vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py
class Nvfp4QuantizationEmulationTritonExperts(TritonExperts):
    """
    Extension of TritonExperts to support emulated NVFP4 MoE experts.

    It may be used for NVFP4 models when the device does not have
    native support for this dtype.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(moe_config, quant_config)
        logger.warning_once(
            "Using Nvfp4QuantizationEmulationTritonExperts MOE backend. This will"
            " dequantize weights on the fly and may be slower than native"
            " quantized MOE. Consider using a device with native quantization"
            " support (e.g. Nvidia Blackwell) for better performance."
        )

        # `TritonExperts.apply` expects pre-dequantized weights,
        # which we handle in `apply` below.
        self.w1_scale_val = self.quant_config.w1_scale
        self.w2_scale_val = self.quant_config.w2_scale

        self.quant_config._w1.scale = None
        self.quant_config._w2.scale = None

        self.quantization_emulation = True

    @property
    def quant_dtype(self) -> torch.dtype | str | None:
        return "nvfp4"

    @property
    def expects_unquantized_inputs(self) -> bool:
        return True

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        return (weight_key, activation_key) == (kNvfp4Static, kNvfp4Dynamic)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """
        Apply emulated quantized MoE computation.

        This dequantizes the weights on the fly and calls fused_experts_impl
        with activation quantization support.
        """
        # Dequantize weights if they are quantized
        # For NVFP4, weights are packed in uint8 format
        # w1 shape: [num_experts, 2*intermediate_size, hidden_size//2]
        # w2 shape: [num_experts, hidden_size, intermediate_size//2]
        assert w1.dtype == torch.uint8
        assert w2.dtype == torch.uint8

        # Dequantize w1 from packed NVFP4 to fp16/bf16
        w13_global_scale = self.quant_config.g1_alphas

        w1_dequant = dequantize_to_dtype(
            tensor_fp4=w1,
            tensor_sf=self.w1_scale_val,
            global_scale=w13_global_scale,
            dtype=hidden_states.dtype,
            block_size=16,
            swizzle=False,
        )

        # Dequantize w2 from packed NVFP4 to fp16/bf16
        w2_global_scale = self.quant_config.g2_alphas

        w2_dequant = dequantize_to_dtype(
            tensor_fp4=w2,
            tensor_sf=self.w2_scale_val,
            global_scale=w2_global_scale,
            dtype=hidden_states.dtype,
            block_size=16,
            swizzle=False,
        )

        hidden_states, _ = moe_kernel_quantize_input(
            A=hidden_states,
            A_scale=self.quant_config.a1_gscale,
            quant_dtype="nvfp4",
            per_act_token_quant=False,
            quantization_emulation=True,
        )

        # Activation quantization/dequantization is deferred to
        # `moe_kernel_quantize_input` in TritonExperts.apply.
        super().apply(
            output=output,
            hidden_states=hidden_states,
            w1=w1_dequant,
            w2=w2_dequant,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            a1q_scale=None,
            a2_scale=self.quant_config.a2_gscale,
            workspace13=workspace13,
            workspace2=workspace2,
            expert_tokens_meta=expert_tokens_meta,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

apply

apply(
    output: Tensor,
    hidden_states: Tensor,
    w1: Tensor,
    w2: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    activation: MoEActivation,
    global_num_experts: int,
    expert_map: Tensor | None,
    a1q_scale: Tensor | None,
    a2_scale: Tensor | None,
    workspace13: Tensor,
    workspace2: Tensor,
    expert_tokens_meta: ExpertTokensMetadata | None,
    apply_router_weight_on_input: bool,
)

Apply emulated quantized MoE computation.

This dequantizes the weights on the fly and calls fused_experts_impl with activation quantization support.

Source code in vllm/model_executor/layers/fused_moe/experts/nvfp4_emulation_moe.py
def apply(
    self,
    output: torch.Tensor,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: MoEActivation,
    global_num_experts: int,
    expert_map: torch.Tensor | None,
    a1q_scale: torch.Tensor | None,
    a2_scale: torch.Tensor | None,
    workspace13: torch.Tensor,
    workspace2: torch.Tensor,
    expert_tokens_meta: mk.ExpertTokensMetadata | None,
    apply_router_weight_on_input: bool,
):
    """
    Apply emulated quantized MoE computation.

    This dequantizes the weights on the fly and calls fused_experts_impl
    with activation quantization support.
    """
    # Dequantize weights if they are quantized
    # For NVFP4, weights are packed in uint8 format
    # w1 shape: [num_experts, 2*intermediate_size, hidden_size//2]
    # w2 shape: [num_experts, hidden_size, intermediate_size//2]
    assert w1.dtype == torch.uint8
    assert w2.dtype == torch.uint8

    # Dequantize w1 from packed NVFP4 to fp16/bf16
    w13_global_scale = self.quant_config.g1_alphas

    w1_dequant = dequantize_to_dtype(
        tensor_fp4=w1,
        tensor_sf=self.w1_scale_val,
        global_scale=w13_global_scale,
        dtype=hidden_states.dtype,
        block_size=16,
        swizzle=False,
    )

    # Dequantize w2 from packed NVFP4 to fp16/bf16
    w2_global_scale = self.quant_config.g2_alphas

    w2_dequant = dequantize_to_dtype(
        tensor_fp4=w2,
        tensor_sf=self.w2_scale_val,
        global_scale=w2_global_scale,
        dtype=hidden_states.dtype,
        block_size=16,
        swizzle=False,
    )

    hidden_states, _ = moe_kernel_quantize_input(
        A=hidden_states,
        A_scale=self.quant_config.a1_gscale,
        quant_dtype="nvfp4",
        per_act_token_quant=False,
        quantization_emulation=True,
    )

    # Activation quantization/dequantization is deferred to
    # `moe_kernel_quantize_input` in TritonExperts.apply.
    super().apply(
        output=output,
        hidden_states=hidden_states,
        w1=w1_dequant,
        w2=w2_dequant,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation=activation,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        a1q_scale=None,
        a2_scale=self.quant_config.a2_gscale,
        workspace13=workspace13,
        workspace2=workspace2,
        expert_tokens_meta=expert_tokens_meta,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )
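
For intuition about the weight dequantization performed in apply above, here is a hedged sketch of what a dequantize_to_dtype-style unpacking does conceptually: each uint8 stores two FP4 codes, every block of 16 values along the last dimension has its own scale factor, and a global scale rescales the result. The nibble order, scale layout, and how the global scale is folded in are illustrative assumptions, not the exact behavior of the real kernel (which also handles FP8-encoded and optionally swizzled scale factors).

import torch

# Magnitude of each FP4 (e2m1) code 0..7; bit 3 of a code is the sign bit.
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def dequant_nvfp4_sketch(
    packed: torch.Tensor,       # uint8, shape [..., K // 2], two FP4 codes per byte
    block_scale: torch.Tensor,  # one scale per 16-value block, shape [..., K // 16]
    global_scale: torch.Tensor, # scalar tensor with the per-tensor scale (assumed)
    dtype: torch.dtype = torch.bfloat16,
    block_size: int = 16,
) -> torch.Tensor:
    # Unpack the low and high nibbles and interleave them back into K values.
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    codes = torch.stack((lo, hi), dim=-1).reshape(*packed.shape[:-1], -1)
    # Decode: bits 0-2 index the magnitude table, bit 3 flips the sign.
    values = E2M1_VALUES.to(packed.device)[(codes & 0x7).long()]
    values = torch.where((codes & 0x8) != 0, -values, values)
    # Apply the per-block scale, then the global scale.
    values = values.reshape(*values.shape[:-1], -1, block_size)
    values = values * block_scale.float().unsqueeze(-1)
    return (values.reshape(*codes.shape) * global_scale.float()).to(dtype)

With hidden_size // 2 packed bytes per row, this recovers a hidden_size-wide BF16 weight that TritonExperts can consume directly, which is the trade-off the warning in __init__ refers to: extra per-forward dequantization work in exchange for running NVFP4 checkpoints on devices without native FP4 support.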