Skip to content

vllm.model_executor.layers.quantization.utils.mxfp4_utils

OCP_MX_BLOCK_SIZE module-attribute

# Block size, in elements, used for OCP MX (microscaling) formats.
# NOTE(review): presumably the number of values that share one scale;
# its use is not visible here -- confirm against callers.
OCP_MX_BLOCK_SIZE = 32

per_token_group_quant_mxfp4

per_token_group_quant_mxfp4(
    x: Tensor,
    block_k: int,
    scale_calculation_mode: str = "even",
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
def per_token_group_quant_mxfp4(x: torch.Tensor,
                                block_k: int,
                                scale_calculation_mode: str = "even"
                                ) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake-quantize ``x`` to MX-FP4 using groups of ``block_k`` elements.

    The input is grouped along its last axis, a scale is derived for each
    group from the group's absolute maximum, and ``x`` is passed through
    quark's FP4 emulation (quantize followed by dequantize).

    Args:
        x: Tensor to quantize; grouping is done along the last axis.
            NOTE(review): presumably the last dimension must be compatible
            with ``block_k`` -- confirm against ``reshape_to_blocks``.
        block_k: Number of elements per quantization group.
        scale_calculation_mode: Rounding strategy for the scales. Only
            ``"even"`` is currently supported.

    Returns:
        A ``(x_dq, scale)`` tuple: the output of the FP4 emulation kernel
        and the scales produced by ``even_round`` from the per-group
        absolute maxima.

    Raises:
        NotImplementedError: If ``scale_calculation_mode`` is not ``"even"``.
        ImportError: If the ``amd-quark`` package is not installed.
    """
    # Reject unsupported modes up front so that no import or tensor work is
    # performed for a call that cannot succeed.
    # TODO: there are other rounding strategies supported in quark and in the
    # config.json that we do not check for here!
    if scale_calculation_mode != "even":
        raise NotImplementedError(
            f"Scale calculation mode {scale_calculation_mode} is not yet "
            "supported in MX-FP4 quantization")

    # quark is an optional dependency, so it is imported lazily.
    try:
        from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
            fake_quantize_fp4_fp6_per_group_with_scale)
        from quark.torch.quantization.utils import (even_round,
                                                    reshape_to_blocks)
    except ImportError as err:
        raise ImportError("The package `amd-quark` is required to use "
                          "MX-FP4 models. Please install it with `pip install "
                          "amd-quark`.") from err

    axis = -1
    block_x = reshape_to_blocks(x, block_k, axis)
    # Absolute maximum over the trailing dimension of the blocked view; the
    # reduced dimension is dropped.
    amax = torch.abs(block_x).max(dim=-1).values
    scale = even_round(amax, "fp4")

    # Apply dequantize(quantize(x)).
    x = fake_quantize_fp4_fp6_per_group_with_scale(
        x,
        scale.to(x.device),
        axis=axis,
        group_size=block_k,
        quant_dtype="fp4",
    )

    return x, scale