Skip to content

vllm.model_executor.kernels.linear.mxfp4.flashinfer

FlashInferMxFp4LinearKernel

Bases: MxFp4LinearKernel

MXFP4 W4A4 GEMM via FlashInfer CUTLASS (SM100+).

Source code in vllm/model_executor/kernels/linear/mxfp4/flashinfer.py
class FlashInferMxFp4LinearKernel(MxFp4LinearKernel):
    """MXFP4 W4A4 GEMM via FlashInfer's ``cute-dsl`` backend (SM100+).

    Activations are quantized on every call with
    ``flashinfer_mxfp4_quantize`` and multiplied against ``layer.weight``
    with ``flashinfer_scaled_fp4_mm``.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report whether this kernel can run on the current platform.

        Args:
            compute_capability: Not consulted by this implementation; the
                check queries ``current_platform`` directly.

        Returns:
            ``(True, None)`` when the platform reports device capability
            100 and ``has_flashinfer_cutedsl()`` is truthy, otherwise
            ``(False, reason)``.
        """
        unsupported = (False, "FlashInfer + >=sm_100 (Blackwell) required")

        # Hardware is checked first; FlashInfer is only probed when the
        # device capability requirement is already met.
        if not current_platform.has_device_capability(100):
            return unsupported
        if not has_flashinfer_cutedsl():
            return unsupported
        return True, None

    @classmethod
    def can_implement(cls, config: MxFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every layer config; always returns ``(True, None)``."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Replace ``layer.weight_scale`` with its swizzled form.

        ``layer.weight_scale`` is read as a 2-D ``(N, scale_K)`` tensor and
        ``K`` is derived as ``scale_K * _MXFP4_GROUP_SIZE``. The swizzled
        result is reshaped to ``(padded_N, -1)`` and stored back on the
        layer as a ``Parameter`` with ``requires_grad=False``.

        Args:
            layer: Module whose ``weight_scale`` attribute is rewritten.
        """
        n_rows, n_scale_cols = layer.weight_scale.shape
        k_dim = n_scale_cols * _MXFP4_GROUP_SIZE

        # swizzle pads N to the next multiple of 128 for CUTLASS tiling,
        # so round the row count up to a whole number of 128-row tiles.
        n_tiles = (n_rows + 127) // 128
        padded_rows = n_tiles * 128

        swizzled = swizzle_mxfp4_scales(layer.weight_scale.data, n_rows, k_dim)
        layer.weight_scale = Parameter(
            swizzled.reshape(padded_rows, -1),
            requires_grad=False,
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the quantized linear layer on ``x``.

        Args:
            layer: Module providing ``weight``, ``weight_scale`` and
                ``output_size_per_partition``.
            x: Input tensor; all leading dimensions are flattened into one
                before the GEMM and restored on the result.
            bias: Optional tensor added to the GEMM output.

        Returns:
            Tensor of shape ``x.shape[:-1] + (output_size_per_partition,)``
            computed with ``out_dtype=x.dtype``.
        """
        from vllm.utils.flashinfer import (
            flashinfer_mxfp4_quantize,
            flashinfer_scaled_fp4_mm,
        )

        w_fp4 = layer.weight
        # Shape the caller expects back: leading dims of x, new last dim.
        result_shape = (*x.shape[:-1], layer.output_size_per_partition)

        # Collapse every leading dimension so the GEMM sees a 2-D input.
        in_features = x.shape[-1]
        flat_x = x.reshape(-1, in_features)

        act_fp4, act_scale = flashinfer_mxfp4_quantize(flat_x)
        result = flashinfer_scaled_fp4_mm(
            act_fp4,
            w_fp4,
            act_scale,
            layer.weight_scale,
            alpha=None,
            out_dtype=x.dtype,
            backend="cute-dsl",
            block_size=_MXFP4_GROUP_SIZE,
            use_nvfp4=False,
        )

        # Out-of-place add so the GEMM output buffer is never mutated.
        result = result if bias is None else result + bias
        return result.view(result_shape)