
vllm.v1.attention.ops.deepseek_v4_ops.fused_inv_rope_fp8_quant

Fused inverse RoPE + block-scaled FP8 quantization kernel for DeepseekV4 attention.

The output scales are produced directly in the pre-transformed layout (MN-major, TMA-aligned; FP32 on SM90, INT32-packed UE8M0 on SM100), so fp8_einsum can skip transform_sf_into_required_layout.
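
A minimal sketch of what that layout means in terms of strides, using CPU tensors only. The shapes mirror the SM90 (FP32) allocation in the source below; rounding the token count up to a multiple of 4 is an assumption about `get_tma_aligned_size(num_tokens, 4)`:

```python
import torch

# Hypothetical sizes for illustration only.
num_tokens, n_groups = 6, 2
num_scale_blocks = 4  # heads_per_group * head_dim // quant_group_size
tma_aligned_T = -(-num_tokens // 4) * 4  # assumed: round token count up to a multiple of 4

# MN-major, TMA-aligned: the token axis is innermost (stride 1) and each
# scale column is padded to tma_aligned_T rows, so fp8_einsum can consume
# the scales without transform_sf_into_required_layout.
scale_buf = torch.empty(
    n_groups * num_scale_blocks * tma_aligned_T, dtype=torch.float32
).as_strided(
    (n_groups, num_tokens, num_scale_blocks),
    (num_scale_blocks * tma_aligned_T, 1, tma_aligned_T),
)
print(scale_buf.shape, scale_buf.stride())  # torch.Size([2, 6, 4]) (32, 1, 8)
```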

fused_inv_rope_fp8_quant

fused_inv_rope_fp8_quant(
    o: Tensor,
    positions: Tensor,
    cos_sin_cache: Tensor,
    n_groups: int,
    heads_per_group: int,
    nope_dim: int = 448,
    rope_dim: int = 64,
    quant_group_size: int = 128,
    tma_aligned_scales: bool = False,
) -> tuple[Tensor, Tensor]

Fused inverse RoPE + block-scaled FP8 quantization.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `o` | `Tensor` | Attention output `[num_tokens, num_heads, head_dim]`, bf16. | required |
| `positions` | `Tensor` | Token positions `[num_tokens]`, int64. | required |
| `cos_sin_cache` | `Tensor` | Precomputed `[max_pos, rope_dim]`, concatenated cos and sin halves. | required |
| `n_groups` | `int` | Number of output groups. | required |
| `heads_per_group` | `int` | Heads per group. | required |
| `nope_dim` | `int` | Non-RoPE dimensions per head. | `448` |
| `rope_dim` | `int` | RoPE dimensions per head. | `64` |
| `quant_group_size` | `int` | FP8 quantization block size. | `128` |
| `tma_aligned_scales` | `bool` | Output INT32-packed UE8M0 for SM100 (`True`) or FP32 for SM90 (`False`). | `False` |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `o_fp8` | `Tensor` | `[T, G, D]` float8_e4m3fn, strides `(D, T*D, 1)`. |
| `o_scale` | `Tensor` | Pre-transformed scale tensor for `fp8_einsum`. |
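
A minimal usage sketch, assuming a CUDA device with the Triton kernel available and the default DeepseekV4 head layout (`nope_dim=448`, `rope_dim=64`, so `head_dim=512`). The random `cos_sin_cache` is only there to exercise shapes; a real cache holds precomputed cos/sin values:

```python
import torch

from vllm.v1.attention.ops.deepseek_v4_ops.fused_inv_rope_fp8_quant import (
    fused_inv_rope_fp8_quant,
)

num_tokens, n_groups, heads_per_group = 16, 2, 8
num_heads, head_dim, rope_dim = n_groups * heads_per_group, 512, 64
max_pos = 4096

o = torch.randn(num_tokens, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
positions = torch.randint(0, max_pos, (num_tokens,), dtype=torch.int64, device="cuda")
cos_sin_cache = torch.randn(max_pos, rope_dim, dtype=torch.float32, device="cuda")

o_fp8, o_scale = fused_inv_rope_fp8_quant(
    o, positions, cos_sin_cache, n_groups, heads_per_group
)
print(o_fp8.shape)  # torch.Size([16, 2, 4096]) == [T, G, heads_per_group * head_dim]
print(o_fp8.dtype)  # torch.float8_e4m3fn
```

Both outputs are laid out for direct consumption by fp8_einsum, as described above.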

Source code in vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py
def fused_inv_rope_fp8_quant(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    n_groups: int,
    heads_per_group: int,
    nope_dim: int = 448,
    rope_dim: int = 64,
    quant_group_size: int = 128,
    tma_aligned_scales: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused inverse RoPE + block-scaled FP8 quantization.

    Args:
        o: Attention output [num_tokens, num_heads, head_dim] bf16.
        positions: Token positions [num_tokens] int64.
        cos_sin_cache: Precomputed [max_pos, rope_dim] with cos||sin.
        n_groups: Number of output groups.
        heads_per_group: Heads per group.
        nope_dim: Non-RoPE dimensions per head (default 448).
        rope_dim: RoPE dimensions per head (default 64).
        quant_group_size: FP8 quantization block size (default 128).
        tma_aligned_scales: Output INT32 packed UE8M0 for SM100 (True)
                            or FP32 for SM90 (False).

    Returns:
        o_fp8: [T, G, D] float8_e4m3fn, strides (D, T*D, 1).
        o_scale: Pre-transformed scale tensor for fp8_einsum.
    """
    from vllm.utils.deep_gemm import get_tma_aligned_size

    num_tokens, num_heads, head_dim = o.shape
    assert num_heads == n_groups * heads_per_group
    assert head_dim == nope_dim + rope_dim
    assert head_dim % quant_group_size == 0
    assert nope_dim % quant_group_size == (quant_group_size - rope_dim)
    assert rope_dim % 2 == 0
    assert cos_sin_cache.shape[-1] == rope_dim
    assert cos_sin_cache.dtype == torch.float32

    d = heads_per_group * head_dim
    num_scale_blocks = d // quant_group_size
    chunks_per_head = head_dim // quant_group_size

    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max

    fp8_buf = torch.empty(
        (n_groups, num_tokens, d),
        dtype=fp8_dtype,
        device=o.device,
    )

    tma_aligned_T = get_tma_aligned_size(num_tokens, 4)
    if tma_aligned_scales:
        packed_sf_k = (num_scale_blocks + 3) // 4
        scale_buf = torch.empty(
            n_groups * packed_sf_k * tma_aligned_T,
            dtype=torch.int32,
            device=o.device,
        ).as_strided(
            (n_groups, num_tokens, packed_sf_k),
            (packed_sf_k * tma_aligned_T, 1, tma_aligned_T),
        )
    else:
        scale_buf = torch.empty(
            n_groups * num_scale_blocks * tma_aligned_T,
            dtype=torch.float32,
            device=o.device,
        ).as_strided(
            (n_groups, num_tokens, num_scale_blocks),
            (num_scale_blocks * tma_aligned_T, 1, tma_aligned_T),
        )

    common_args = dict(
        heads_per_group=heads_per_group,
        o_stride_token=o.stride(0),
        o_stride_head=o.stride(1),
        cache_stride_pos=cos_sin_cache.stride(0),
        fp8_stride_group=fp8_buf.stride(0),
        fp8_stride_token=fp8_buf.stride(1),
        scale_stride_group=scale_buf.stride(0),
        scale_stride_k=scale_buf.stride(2),
        fp8_max=fp8_max,
        eps=1e-10,
        QUANT_GROUP_SIZE=quant_group_size,
        CHUNKS_PER_HEAD=chunks_per_head,
        ROPE_START=nope_dim % quant_group_size,
        HALF_ROPE=rope_dim // 2,
        TMA_ALIGNED_SCALES=tma_aligned_scales,
        num_stages=1,
        launch_pdl=False,
    )

    grid = (tma_aligned_T, n_groups * heads_per_group)
    _fused_inv_rope_fp8_quant_per_head[grid](
        o,
        positions,
        cos_sin_cache,
        fp8_buf,
        scale_buf,
        num_tokens,
        **common_args,
        num_warps=1,
    )

    return fp8_buf.transpose(0, 1), scale_buf.transpose(0, 1)
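
For reference, the documented `[T, G, D]` layout with strides `(D, T*D, 1)` falls out of allocating the FP8 buffer group-major and transposing it at the end; a CPU-only sketch with hypothetical sizes:

```python
import torch

T, G, D = 16, 2, 4096  # hypothetical; D = heads_per_group * head_dim
fp8_buf = torch.empty(G, T, D, dtype=torch.float8_e4m3fn)  # group-major, contiguous
o_fp8 = fp8_buf.transpose(0, 1)  # view with shape [T, G, D]
print(o_fp8.shape, o_fp8.stride())  # torch.Size([16, 2, 4096]) (4096, 65536, 1) == (D, T*D, 1)
```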