vllm.kernels.triton.qkv_padded_fp8_quant

Stride-aware FP8 quantization with head_dim padding for ViT attention.

Reads directly from non-contiguous QKV views using 3D strides and pads head_dim to a multiple of 16 for cuDNN compatibility.
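For context, a minimal sketch of how such non-contiguous views arise (the interleaved (S, 3, H, D) layout and the shapes here are illustrative assumptions, not taken from a specific vLLM model):

import torch

# Hypothetical fused QKV projection output laid out as (S, 3, H, D),
# with head_dim = 72 (not a multiple of 16).
S, H, D = 1024, 16, 72
qkv = torch.randn(S, 3, H, D, device="cuda", dtype=torch.bfloat16)

# Slicing out Q yields an (S, H, D) view that shares qkv's storage, so its
# sequence stride is 3 * H * D rather than H * D: the view is non-contiguous.
q = qkv[:, 0]
print(q.shape, q.stride(), q.is_contiguous())
# torch.Size([1024, 16, 72]) (3456, 72, 1) False

A stride-aware kernel can read q, k, and v directly; a naive path would first call .contiguous() and pay for an extra copy.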

quantize_fp8_maybe_pad_head_dim

quantize_fp8_maybe_pad_head_dim(
    tensor: Tensor,
    scale: Tensor,
    fp8_quant: QuantFP8,
    skip_scale: bool = False,
) -> Tensor

Quantize a 3D/4D tensor to FP8, padding head_dim to a multiple of 16 only when needed.

Accepts (S, H, D) or (B, S, H, D) input. Uses fp8_quant (a QuantFP8 CustomOp) when head_dim is already aligned to 16 (no padding); otherwise falls back to a stride-aware Triton kernel that pads head_dim to a multiple of 16.

Source code in vllm/kernels/triton/qkv_padded_fp8_quant.py
def quantize_fp8_maybe_pad_head_dim(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    fp8_quant: QuantFP8,
    skip_scale: bool = False,
) -> torch.Tensor:
    """Quantize a 3D/4D tensor to FP8, padding head_dim to a multiple of 16
    only when needed.

    Accepts (S, H, D) or (B, S, H, D) input. Uses ``fp8_quant`` (a
    :class:`QuantFP8` CustomOp) when head_dim is already aligned to 16
    (no padding); otherwise falls back to a stride-aware Triton kernel
    that pads head_dim to a multiple of 16.
    """
    head_dim = tensor.shape[-1]
    if head_dim % 16 != 0:
        return quantize_fp8_pad_head_dim_triton(tensor, scale, skip_scale=skip_scale)

    if skip_scale:
        return tensor.to(current_platform.fp8_dtype())

    # QuantFP8 expects 2D: flatten all dims except (H, D).
    orig_shape = tensor.shape
    total_tokens = tensor.numel() // (orig_shape[-1] * orig_shape[-2])
    tensor_2d = tensor.reshape(total_tokens, -1)
    fp8_tensor, _ = fp8_quant(tensor_2d, scale=scale)
    return fp8_tensor.reshape(orig_shape)
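For intuition, a minimal pure-PyTorch reference of the per-tensor quantization performed on the aligned path (a sketch assuming a static per-tensor scale and the CUDA FP8 dtype; the actual QuantFP8 CustomOp dispatches to a fused implementation):

import torch

def reference_fp8_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Divide by the per-tensor scale, clamp to the representable FP8 range,
    # and cast. The clamp bounds mirror the _FP8_MIN/_FP8_MAX constants passed
    # to the Triton kernel further down.
    fp8_dtype = torch.float8_e4m3fn  # assumption: CUDA; other platforms may use a different FP8 dtype
    finfo = torch.finfo(fp8_dtype)
    return (x.float() / scale.float()).clamp(finfo.min, finfo.max).to(fp8_dtype)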

quantize_fp8_pad_head_dim_triton

quantize_fp8_pad_head_dim_triton(
    tensor: Tensor,
    scale: Tensor,
    skip_scale: bool = False,
    block_m: int | None = None,
    block_n: int | None = None,
    num_warps: int | None = None,
) -> Tensor

Quantize a 3D/4D tensor to FP8, padding head_dim to a multiple of 16.

Reads directly from the input using its 3D strides, so non-contiguous views (e.g. Q/K/V slices from an interleaved QKV buffer) are handled without an extra copy. Output is always a fresh contiguous tensor with shape (S, H, padded_D), or (B, S, H, padded_D) for 4D input.

Source code in vllm/kernels/triton/qkv_padded_fp8_quant.py
def quantize_fp8_pad_head_dim_triton(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    skip_scale: bool = False,
    block_m: int | None = None,
    block_n: int | None = None,
    num_warps: int | None = None,
) -> torch.Tensor:
    """Quantize a 3D/4D tensor to FP8, padding head_dim to a multiple of 16.

    Reads directly from the input using its 3D strides, so non-contiguous
    views (e.g. Q/K/V slices from an interleaved QKV buffer) are handled
    without an extra copy.  Output is always a fresh contiguous tensor
    with shape (S, H, padded_D), or (B, S, H, padded_D) for 4D input.
    """
    if not HAS_TRITON:
        raise RuntimeError("Triton is required to quantize with head_dim padding.")

    original_shape = tensor.shape
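    # Collapse a 4D (B, S, H, D) input to 3D (B * S, H, D); the original
    # leading dims are restored (with the padded head_dim) on return.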
    if tensor.dim() == 4:
        tensor = tensor.view(-1, tensor.shape[-2], tensor.shape[-1])
    assert tensor.dim() == 3, f"Expected 3D input (S, H, D), got {tensor.dim()}D"
    S, H, D = tensor.shape
    padded_head_dim = round_up(D, 16)
    out_dtype = current_platform.fp8_dtype()
    output = torch.empty(
        (S, H, padded_head_dim),
        device=tensor.device,
        dtype=out_dtype,
    )

    scale_1d = scale.reshape(-1)
    n_rows = S * H

    if block_m is None or block_n is None or num_warps is None:
        block_m, block_n, num_warps = _get_fp8_pad_quant_config(padded_head_dim)

    # 2D launch grid: axis 0 tiles the flattened (S * H) rows, axis 1 tiles
    # the padded head dimension.
    grid = (
        triton.cdiv(n_rows, block_m),
        triton.cdiv(padded_head_dim, block_n),
    )

    _quantize_pad_fp8_kernel[grid](
        tensor,
        output,
        scale_1d,
        tensor.stride(0),
        tensor.stride(1),
        tensor.stride(2),
        output.stride(0),
        output.stride(1),
        output.stride(2),
        H,
        n_rows,
        D,
        padded_head_dim,
        _FP8_MIN,
        _FP8_MAX,
        SKIP_SCALE=skip_scale,
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        num_warps=num_warps,
    )

    return output.view((*original_shape[:-1], padded_head_dim))
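A hedged usage sketch of the Triton path (assumes a CUDA device with Triton available, a per-tensor scale, and the illustrative shapes below; head_dim 72 is padded up to 80):

import torch
from vllm.kernels.triton.qkv_padded_fp8_quant import quantize_fp8_pad_head_dim_triton

S, H, D = 1024, 16, 72
qkv = torch.randn(S, 3, H, D, device="cuda", dtype=torch.bfloat16)
q = qkv[:, 0]                                # non-contiguous (S, H, D) view
scale = torch.tensor([1.0], device="cuda")   # per-tensor scale

q_fp8 = quantize_fp8_pad_head_dim_triton(q, scale)
print(q_fp8.shape)            # torch.Size([1024, 16, 80])
print(q_fp8.is_contiguous())  # True: output is a fresh contiguous buffer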