Skip to content

vllm.model_executor.layers.quantization.quark.schemes.quark_w4a8_mxfp4_fp8

QuarkW4A8_MXFP4_FP8

Bases: QuarkScheme

  • Weights: MXFP4 with E8M0 scales per block of 32
  • Activations: FP8 E4M3 (static per-tensor quantization)

Uses the AITER Triton kernel and falls back to emulation if AITER is not available.

Source code in vllm/model_executor/layers/quantization/quark/schemes/quark_w4a8_mxfp4_fp8.py
class QuarkW4A8_MXFP4_FP8(QuarkScheme):
    """Quark W4A8 linear scheme: MXFP4 weights with static FP8 activations.

    - Weights: MXFP4 with E8M0 scales per block of 32
    - Activations: FP8 E4M3 (static per-tensor quantization)

    Uses the AITER Triton kernel and falls back to emulation if AITER not available.
    """

    def __init__(
        self,
        weight_quant_spec: dict[str, Any],
        input_quant_spec: dict[str, Any],
    ):
        """Configure the scheme from Quark weight/input quantization specs.

        Args:
            weight_quant_spec: Quark spec describing the weight quantization.
            input_quant_spec: Quark spec describing the activation
                quantization; must request static (non-dynamic) quantization.

        Raises:
            NotImplementedError: If the input spec requests dynamic
                activation quantization.
        """
        # None means "match the activation dtype" at apply time.
        self.out_dtype = None

        self.weight_dtype = "mxfp4"
        self.packed_factor: Fraction = Fraction(2, 1)  # 2 FP4 values per byte
        self.weight_block_size = OCP_MX_BLOCK_SIZE

        self.is_static_input_scheme = not input_quant_spec.get("is_dynamic")
        self.input_qscheme = input_quant_spec.get("qscheme")  # "per_tensor"

        self.fp8_min, self.fp8_max = get_fp8_min_max()
        self.fp8_dtype = current_platform.fp8_dtype()

        if not self.is_static_input_scheme:
            raise NotImplementedError(
                "Dynamic FP8 activation quantization is not yet supported "
                "for W4A8. The current implementation expects static per-tensor "
                "FP8 scales stored in the checkpoint."
            )

        # The fused AITER kernel path is only enabled on gfx950 hardware.
        kernel_supported_gpu = False
        if current_platform.is_rocm():
            from vllm.platforms.rocm import on_gfx950

            kernel_supported_gpu = on_gfx950()

        self.use_aiter_kernel = (
            is_aiter_found_and_supported()
            and self.is_static_input_scheme
            and kernel_supported_gpu
        )

        if not self.use_aiter_kernel:
            logger.warning_once(
                "[W4A8 MXFP4+FP8] Aiter Triton kernel not found. Using emulation mode."
            )

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU compute capability supported (7.0)."""
        return 70

    def get_packed_dim(self, dim: int) -> int:
        """Return the byte length of ``dim`` FP4 values (two per byte).

        Raises:
            AssertionError: If ``dim`` is odd and therefore cannot be packed.
        """
        assert dim % 2 == 0, f"Dimension {dim} must be even for MXFP4 packing"
        return dim // 2

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register the packed weight, block scales, and static input scale.

        Args:
            layer: Module the parameters are registered on.
            output_partition_sizes: Output sizes of the (possibly fused)
                logical sub-layers, e.g. Q/K/V for a fused QKV projection.
            input_size_per_partition: Input feature size of this TP shard.
            params_dtype: Requested parameter dtype (unused; storage is uint8
                for both packed FP4 weights and E8M0 scales).
            weight_loader: Callback used to load checkpoint shards.
        """
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # MXFP4 WEIGHT (packed, 2 values per byte)
        weight = PackedvLLMParameter(
            data=torch.empty(
                output_size_per_partition,
                self.get_packed_dim(input_size_per_partition),
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=self.packed_factor,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE (E8M0 format, per block of 32)
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.weight_block_size,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE (FP8 per-tensor static scale)
        if self.is_static_input_scheme:
            input_scale = PerTensorScaleParameter(
                data=torch.empty(
                    len(output_partition_sizes),
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            # Fill with finfo.min as a sentinel so torch.empty garbage never
            # leaks into computation if a scale is not overwritten at load
            # time (and NaNs cannot appear from uninitialized memory).
            input_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Finalize loaded parameters: freeze them and fold fused scales."""
        # Ensuring weights & scales are non-trainable
        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(
            layer.weight_scale.data, requires_grad=False
        )

        if self.is_static_input_scheme:
            input_scale = layer.input_scale.data
            # For fused modules (QKV), take the max scale
            if input_scale.numel() != 1:
                input_scale = input_scale.max()

            layer.input_scale = torch.nn.Parameter(
                # detach().clone() rather than torch.tensor(tensor): identical
                # result, but avoids PyTorch's copy-construct UserWarning.
                input_scale.detach().clone().to(dtype=torch.float32),
                requires_grad=False,
            )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ W^T (+ bias)`` via the AITER kernel or emulation."""
        if self.use_aiter_kernel:
            return self._apply_aiter_kernel(layer, x, bias)
        else:
            return self._apply_emulation(layer, x, bias)

    def _apply_aiter_kernel(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize activations to FP8 and run the fused AITER GEMM.

        Assumes ``x`` is 2-D ``(M, K)``, as the kernel consumes per-row
        activation scales of shape ``(M, 1)``.
        """
        M = x.shape[0]
        out_dtype = x.dtype if self.out_dtype is None else self.out_dtype

        # Static per-tensor FP8 quantization of the activations.
        input_scale = layer.input_scale
        x_fp8 = (x / input_scale).clamp(self.fp8_min, self.fp8_max).to(self.fp8_dtype)

        # Broadcast per-tensor scale to per-row (M, 1) for Aiter kernel
        x_scales = input_scale.expand(M, 1).to(dtype=torch.float32, device=x.device)

        y = rocm_aiter_ops.gemm_a8wfp4(
            x_fp8, layer.weight, x_scales, layer.weight_scale, out_dtype
        )

        if bias is not None:
            y = y + bias

        return y

    def _apply_emulation(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Emulate W4A8 numerics: dequantize weights, fake-quantize inputs.

        Round-trips the activations through FP8 (quantize then dequantize)
        so the emulated path matches the kernel's activation precision, then
        performs the matmul in ``x.dtype``.
        """
        from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
            dequant_mxfp4,
        )

        weight_dq = dequant_mxfp4(
            layer.weight,
            layer.weight_scale,
            x.dtype,
        )

        input_scale = layer.input_scale
        x_fp8 = (x / input_scale).clamp(self.fp8_min, self.fp8_max).to(self.fp8_dtype)
        x_dq = (x_fp8.to(x.dtype) * input_scale).to(x.dtype)

        return F.linear(x_dq, weight_dq, bias)