Skip to content

vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter

AiterScaledMMLinearKernel

Bases: CutlassScaledMMLinearKernel

Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(
            cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
        if not current_platform.is_rocm():
            return (
                False,
                "AiterScaledMMLinearKernel requires `aiter` which is not " +
                "currently supported on non-ROCm platform.")

        try:
            import aiter  # noqa: F401 # deliberately attempt to import aiter
        except Exception:
            return (
                False,
                "AiterScaledMMLinearKernel requires `aiter` which is not " +
                "installed on ROCm.")
        # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
        if not (
            envs.VLLM_ROCM_USE_AITER_LINEAR \
            and envs.VLLM_ROCM_USE_AITER
        ):
            return (False, "AiterScaledMMLinearKernel is disabled. " +
                    "Enable by setting `VLLM_ROCM_USE_AITER=1` " +
                    "and `VLLM_ROCM_USE_AITER_LINEAR=1`. " +
                    "`VLLM_ROCM_USE_AITER_LINEAR` default is True.")

        if not c.input_symmetric:
            return (False,
                    "AiterScaledMMLinearKernel only supports symmetric " +
                    "quantization.")
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        super().process_weights_after_loading(layer)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        `AiterScaledMMLinearKernel` implements a fused version of
            `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
        where scale_a * a and scale_b * b are implemented using numpy-style
        broadcasting.
        Currently only support per-tensor-per-tensor GEMM
        and per-token-per-channel GEMM through AITER
        w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
        ATIER block scaled GEMM and mix-precision GEMM.
        """
        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)

        # ops.scaled_int8_quant supports both dynamic and static quant:
        # * dynamic, i_s is None and x_s computed from x.
        # * static, i_s is scalar and x_s is i_s.
        symmetric = azp_adj is None
        assert symmetric, ("AiterScaledMMLinearKernel only supports"
                           " symmetric quantization.")
        x_q, x_s, x_zp = ops.scaled_int8_quant(x,
                                               i_s,
                                               i_zp,
                                               symmetric=symmetric)

        assert x_zp is None, ("AiterScaledMMLinearKernel only supports"
                              " symmetric quantization.")
        out_dtype = x.dtype

        assert (w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0)
        assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
        assert bias is None or bias.shape[0] == w_q.shape[
            1] and bias.dtype == out_dtype

        m = x_q.shape[0]  # a
        n = w_q.shape[1]  # b

        per_tensor_scale_a = (x_s.numel() == 1)
        per_tensor_scale_b = (w_s.numel() == 1)
        per_token_scale_a = (x_s.numel() == m)
        per_channel_scale_b = (w_s.numel() == n)

        # @TODO:
        # Maybe broadcast the per-tensor-scale into per-channel-scale
        # if one of the scale is a per-channel-scale.
        # For now, it only supports:
        # - per-tensor-per-tensor a8w8 scaled GEMM, and
        # - per-token-per-channel a8w8 scaled GEMM
        assert ((per_tensor_scale_a and per_tensor_scale_b)
                or (per_token_scale_a and per_channel_scale_b)), (
                    "Currently only support per-tensor-per-tensor GEMM " +
                    " and per-token-per-channel GEMM through AITER"
                    " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " +
                    "does not support AITER block scaled GEMM.")

        from aiter import gemm_a8w8_CK

        # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
        # a to be [M, K]
        # b to be [N, K]
        # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
        return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype)

apply_weights

apply_weights(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor

AiterScaledMMLinearKernel implements a fused version of output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype) where scale_a * a and scale_b * b are implemented using numpy-style broadcasting. Currently only support per-tensor-per-tensor GEMM and per-token-per-channel GEMM through AITER w8a8 scaled gemm. AiterScaledMMLinearKernel also does not support ATIER block scaled GEMM and mix-precision GEMM.

Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
def apply_weights(self,
                  layer: torch.nn.Module,
                  x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    `AiterScaledMMLinearKernel` implements a fused version of
        `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
    where scale_a * a and scale_b * b are implemented using numpy-style
    broadcasting.
    Currently only support per-tensor-per-tensor GEMM
    and per-token-per-channel GEMM through AITER
    w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
    ATIER block scaled GEMM and mix-precision GEMM.
    """
    w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)

    # ops.scaled_int8_quant supports both dynamic and static quant:
    # * dynamic, i_s is None and x_s computed from x.
    # * static, i_s is scalar and x_s is i_s.
    symmetric = azp_adj is None
    assert symmetric, ("AiterScaledMMLinearKernel only supports"
                       " symmetric quantization.")
    x_q, x_s, x_zp = ops.scaled_int8_quant(x,
                                           i_s,
                                           i_zp,
                                           symmetric=symmetric)

    assert x_zp is None, ("AiterScaledMMLinearKernel only supports"
                          " symmetric quantization.")
    out_dtype = x.dtype

    assert (w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or bias.shape[0] == w_q.shape[
        1] and bias.dtype == out_dtype

    m = x_q.shape[0]  # a
    n = w_q.shape[1]  # b

    per_tensor_scale_a = (x_s.numel() == 1)
    per_tensor_scale_b = (w_s.numel() == 1)
    per_token_scale_a = (x_s.numel() == m)
    per_channel_scale_b = (w_s.numel() == n)

    # @TODO:
    # Maybe broadcast the per-tensor-scale into per-channel-scale
    # if one of the scale is a per-channel-scale.
    # For now, it only supports:
    # - per-tensor-per-tensor a8w8 scaled GEMM, and
    # - per-token-per-channel a8w8 scaled GEMM
    assert ((per_tensor_scale_a and per_tensor_scale_b)
            or (per_token_scale_a and per_channel_scale_b)), (
                "Currently only support per-tensor-per-tensor GEMM " +
                " and per-token-per-channel GEMM through AITER"
                " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " +
                "does not support AITER block scaled GEMM.")

    from aiter import gemm_a8w8_CK

    # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
    # a to be [M, K]
    # b to be [N, K]
    # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
    return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype)

can_implement classmethod

can_implement(
    c: ScaledMMLinearLayerConfig,
) -> tuple[bool, Optional[str]]
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
@classmethod
def can_implement(
        cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]:
    if not current_platform.is_rocm():
        return (
            False,
            "AiterScaledMMLinearKernel requires `aiter` which is not " +
            "currently supported on non-ROCm platform.")

    try:
        import aiter  # noqa: F401 # deliberately attempt to import aiter
    except Exception:
        return (
            False,
            "AiterScaledMMLinearKernel requires `aiter` which is not " +
            "installed on ROCm.")
    # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
    if not (
        envs.VLLM_ROCM_USE_AITER_LINEAR \
        and envs.VLLM_ROCM_USE_AITER
    ):
        return (False, "AiterScaledMMLinearKernel is disabled. " +
                "Enable by setting `VLLM_ROCM_USE_AITER=1` " +
                "and `VLLM_ROCM_USE_AITER_LINEAR=1`. " +
                "`VLLM_ROCM_USE_AITER_LINEAR` default is True.")

    if not c.input_symmetric:
        return (False,
                "AiterScaledMMLinearKernel only supports symmetric " +
                "quantization.")
    return True, None

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
@classmethod
def get_min_capability(cls) -> int:
    return 90

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    super().process_weights_after_loading(layer)