Skip to content

vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q

_e2m1_nibble

_e2m1_nibble(x)

Quantize fp32 x (already scale-divided) to E2M1 4-bit nibble in uint8. Matches torch.bucketize with boundaries [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0] and right=False (each boundary belongs to the lower bucket), plus sign bit.

Source code in vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py
@triton.jit
def _e2m1_nibble(x):
    """Map scale-divided fp32 values to 4-bit E2M1 codes held in uint8.

    Bits 0-2 carry the magnitude code (0..7) and bit 3 carries the sign.
    The magnitude code equals torch.bucketize(|x|, boundaries, right=False)
    with boundaries [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0] (the midpoints
    between consecutive E2M1 magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6), so a
    value sitting exactly on a boundary lands in the lower bucket.
    Magnitudes are clamped to 6.0 first. The sign bit is set only for
    negative inputs with a non-zero code, so negative zero is never emitted.
    """
    mag = tl.minimum(tl.abs(x), 6.0)
    # Count the boundaries that mag does not exceed. A value <= 0.25 passes
    # all seven tests (code 0); a value > 5.0 passes none (code 7). A NaN
    # passes none either, which matches the fall-through of a nested
    # `<=` select chain.
    n_le = (
        (mag <= 0.25).to(tl.int32)
        + (mag <= 0.75).to(tl.int32)
        + (mag <= 1.25).to(tl.int32)
        + (mag <= 1.75).to(tl.int32)
        + (mag <= 2.5).to(tl.int32)
        + (mag <= 3.5).to(tl.int32)
        + (mag <= 5.0).to(tl.int32)
    )
    magnitude = (7 - n_le).to(tl.uint8)
    is_neg = (x < 0) & (magnitude != 0)
    return magnitude | (is_neg.to(tl.uint8) << 3)

_quantize_mxfp4_pair

_quantize_mxfp4_pair(x_lo, x_hi)

Quantize a block of MXFP4_BLOCK_SIZE fp32 values given as two interleaved halves (x_lo = values at even positions in the block, x_hi = values at odd positions). Returns:

- packed: uint8[BLOCK/2] (low nibble = quant(x_lo), high nibble = quant(x_hi))
- ue8m0: scalar uint8 (block scale = 2^(ue8m0 - 127))

Source code in vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py
@triton.jit
def _quantize_mxfp4_pair(x_lo, x_hi):
    """MXFP4-quantize one block whose values arrive as two interleaved halves.

    x_lo holds the block's even-position values and x_hi its odd-position
    values. Both halves share a single power-of-two scale.

    Returns:
        packed : uint8[BLOCK/2], low nibble = code(x_lo), high nibble =
                 code(x_hi)
        ue8m0  : scalar uint8 biased exponent; block scale = 2^(ue8m0 - 127)
    """
    # Block amplitude over both halves. The 1e-4 floor keeps log2 finite
    # for an all-zero block.
    peak = tl.maximum(tl.max(tl.abs(x_lo)), tl.max(tl.abs(x_hi)))
    peak = tl.maximum(peak, 1e-4)

    # Scale exponent = ceil(log2(peak / 6.0)), where 6.0 is the largest
    # E2M1 magnitude. It is clamped to [-127, 127] so the biased value
    # fits a ue8m0 byte.
    block_exp = tl.math.ceil(tl.math.log2(peak / 6.0))
    block_exp = tl.minimum(tl.maximum(block_exp, -127.0), 127.0)
    ue8m0 = (block_exp + 127.0).to(tl.uint8)
    inv_scale = 1.0 / tl.math.exp2(block_exp)

    lo_codes = _e2m1_nibble(x_lo * inv_scale)
    hi_codes = _e2m1_nibble(x_hi * inv_scale)
    return lo_codes | (hi_codes << 4), ue8m0

fused_indexer_q_rope_quant

fused_indexer_q_rope_quant(
    positions: Tensor,
    index_q: Tensor,
    index_q_cos_sin_cache: Tensor,
    index_weights: Tensor,
    index_weights_softmax_scale: float,
    index_weights_head_scale: float,
    use_fp4: bool = False,
) -> tuple[Tensor | tuple[Tensor, Tensor], Tensor]

Fused RoPE + quantize Q for the sparse indexer.

Weight-fold semantics (important — the two paths differ):

FP8 path (`use_fp4=False`, default):

- `q_fp8`: (T, H, HEAD_DIM) float8_e4m3fn, with one scalar scale per (token, head). That scale is NOT stored; it is folded into the weights below.
- `weights_out = weights * q_scale * softmax_scale * head_scale`
- Rationale: the per-(token, head) q_scale is a scalar that the downstream FP8 logits kernel would otherwise multiply in. Folding it into `weights` avoids emitting a separate tensor and is free for the logits kernel.

MXFP4 path (`use_fp4=True`):

- `q_packed`: (T, H, HEAD_DIM // 2) uint8 (two E2M1 nibbles per byte).
- `q_scale`: HEAD_DIM // MXFP4_BLOCK_SIZE ue8m0 bytes per (token, head). The function returns these 4 bytes reinterpreted as a single int32, giving a tensor of shape (T, H).
- `weights_out = weights * softmax_scale * head_scale`
- Rationale: MXFP4 has PER-BLOCK (32-element) scales that live with the Q values. They cannot be folded into a per-token weight scalar, so `weights` carries only the softmax and head scales.

Returns (q_quant, weights_out) where q_quant is either a Tensor (FP8) or a (values, scales) tuple (MXFP4). This matches the union type accepted by SparseAttnIndexer.forward_*.

Source code in vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py
def fused_indexer_q_rope_quant(
    positions: torch.Tensor,
    index_q: torch.Tensor,
    index_q_cos_sin_cache: torch.Tensor,
    # Index weights
    index_weights: torch.Tensor,
    index_weights_softmax_scale: float,
    index_weights_head_scale: float,
    use_fp4: bool = False,
) -> tuple[
    torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    torch.Tensor,
]:
    """Fused RoPE + quantize Q for the sparse indexer.

    Weight-fold semantics (important — the two paths differ):

    FP8 path (use_fp4=False, default):
        q_fp8      : (T, H, HEAD_DIM) float8_e4m3fn, per-token-per-head
                     scalar scale (NOT stored — folded into weights below)
        weights_out = weights * q_scale * softmax_scale * head_scale
        Rationale: a single per-token q_scale is a scalar the downstream FP8
        logits kernel would otherwise multiply in. Folding it into `weights`
        avoids emitting a separate tensor and is free for the logits kernel.

    MXFP4 path (use_fp4=True):
        q_packed   : (T, H, HEAD_DIM // 2) uint8 (2 E2M1 nibbles per byte)
        q_scale    : (T, H, HEAD_DIM // MXFP4_BLOCK_SIZE) uint8 ue8m0 bytes
        weights_out = weights * softmax_scale * head_scale
        Rationale: MXFP4 has PER-BLOCK (32-element) scales that live with
        the Q values — they cannot be folded into a per-token weight
        scalar, so `weights` carries only the softmax and head scales.

    Returns (q_quant, weights_out) where q_quant is either a Tensor (FP8) or
    a (values, scales) tuple (MXFP4). This matches the union type accepted
    by `SparseAttnIndexer.forward_*`.
    """
    assert positions.ndim == 1
    assert index_q.ndim == 3
    assert index_q_cos_sin_cache.ndim == 2

    num_tokens = positions.shape[0]
    num_index_q_heads = index_q.shape[1]
    index_q_head_dim = index_q.shape[2]

    index_weights_out = torch.empty_like(index_weights, dtype=torch.float32)

    if use_fp4:
        assert index_q_head_dim % MXFP4_BLOCK_SIZE == 0, (
            f"head_dim={index_q_head_dim} must be a multiple of MXFP4 block "
            f"size {MXFP4_BLOCK_SIZE}"
        )
        num_scale_blocks = index_q_head_dim // MXFP4_BLOCK_SIZE
        index_q_packed = torch.empty(
            (num_tokens, num_index_q_heads, index_q_head_dim // 2),
            dtype=torch.uint8,
            device=index_q.device,
        )
        index_q_scale = torch.empty(
            (num_tokens, num_index_q_heads, num_scale_blocks),
            dtype=torch.uint8,
            device=index_q.device,
        )
        _fused_indexer_q_rope_mxfp4_kernel[(num_tokens, num_index_q_heads)](
            positions,
            index_q,
            index_q.stride(0),
            index_q.stride(1),
            index_q_cos_sin_cache,
            index_q_cos_sin_cache.stride(0),
            index_q_cos_sin_cache.shape[-1] // 2,
            index_q_packed,
            index_q_packed.stride(0),
            index_q_packed.stride(1),
            index_q_scale,
            index_q_scale.stride(0),
            index_q_scale.stride(1),
            index_q_head_dim,
            MXFP4_BLOCK_SIZE,
            index_weights,
            index_weights.stride(0),
            index_weights_softmax_scale,
            index_weights_head_scale,
            index_weights_out,
            index_weights_out.stride(0),
            num_warps=1,  # TODO: Tune this
        )
        # Values stay uint8 (2 E2M1 nibbles per byte). Scales are 4 ue8m0
        # bytes per (token, head) reinterpreted as one int32, then squeezed
        # from (T, H, 1) to (T, H) to match DeepGEMM's expected q_sf rank
        # (prefill wants 2-D (seq_len, num_heads); decode reshapes this to
        # 3-D (batch, next_n, num_heads)).
        return (
            index_q_packed,
            index_q_scale.view(torch.int32).squeeze(-1),
        ), index_weights_out

    index_q_fp8 = torch.empty_like(index_q, dtype=torch.float8_e4m3fn)
    _fused_indexer_q_rope_quant_kernel[(num_tokens, num_index_q_heads)](
        positions,
        index_q,
        index_q.stride(0),
        index_q.stride(1),
        index_q_cos_sin_cache,
        index_q_cos_sin_cache.stride(0),
        index_q_cos_sin_cache.shape[-1] // 2,
        index_q_fp8,
        index_q_fp8.stride(0),
        index_q_fp8.stride(1),
        index_q_head_dim,
        index_weights,
        index_weights.stride(0),
        index_weights_softmax_scale,
        index_weights_head_scale,
        index_weights_out,
        index_weights_out.stride(0),
        num_warps=1,  # TODO: Tune this
    )
    return index_q_fp8, index_weights_out