
vllm.models.deepseek_v4.common.ops.fused_mtp_input_rmsnorm

Fused MTP-input RMSNorm: enorm (with rows at position 0 masked to zero) + hnorm.

Replaces the eager sequence at the top of the MTP draft forward:

    inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
    inputs_embeds = self.enorm(inputs_embeds)
    previous_hidden_states = previous_hidden_states.view(-1, hc_mult, H)
    previous_hidden_states = self.hnorm(previous_hidden_states)

which lowers to roughly six small kernels (CompareEq, where, Fill, the enorm rms_norm, the hnorm rms_norm, plus aten elementwise helpers) on the breakable-cudagraph path. The math is preserved: a token with positions == 0 has its embedding row masked to zero, and the RMSNorm of an all-zero row is zero regardless of the weight.
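As a plain-Python reference (no torch; `rmsnorm_row` and `eager_enorm` are hypothetical helper names), the eager enorm step and the zero-row property can be sketched as below. It assumes the standard RMSNorm formula x / sqrt(mean(x²) + eps) · w; note that eps > 0 is what keeps the masked row finite, since the mean of squares of a zero row is 0.

```python
import math


def rmsnorm_row(x, weight, eps):
    """Reference RMSNorm for one row: x / sqrt(mean(x**2) + eps) * weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def eager_enorm(inputs_embeds, positions, enorm_weight, eps):
    """Zero every row whose position is 0 (the torch.where step), then apply enorm."""
    hidden = len(enorm_weight)
    masked = [
        [0.0] * hidden if pos == 0 else list(row)
        for row, pos in zip(inputs_embeds, positions)
    ]
    return [rmsnorm_row(row, enorm_weight, eps) for row in masked]


embeds = [[1.0, -2.0, 3.0, 0.5], [4.0, 1.0, -1.0, 2.0]]
positions = [0, 7]
weight = [10.0, -3.0, 0.25, 99.0]  # arbitrary, including a negative entry
out = eager_enorm(embeds, positions, weight, eps=1e-6)
# Row 0 sits at position 0, so every entry of out[0] compares equal to 0.0
# (some may be -0.0 where the weight is negative).
```

The masked row contributes 0 · inv_rms · w for every lane, so no choice of weight can make it nonzero.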

A single launch over a (T, hc_mult + 1) grid drives both norms: for each token, task 0 is enorm on inputs_embeds[token, :] and task k + 1 (for k in 0 .. hc_mult - 1) is hnorm on previous_hidden_states[token, k, :].
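The task split can be emulated on the host. The sketch below is a sequential plain-Python stand-in for the Triton grid, not the kernel itself; `fused_program` and `launch` are hypothetical names, and it assumes grid axis 0 indexes the token and axis 1 indexes the task, matching the (num_tokens, hc_mult + 1) launch in the source.

```python
import math


def rmsnorm_row(x, weight, eps):
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def fused_program(token, task, inputs_embeds, positions, prev_hidden,
                  enorm_w, hnorm_w, eps, enorm_out, hnorm_out):
    """One (token, task) program of the (T, hc_mult + 1) grid."""
    if task == 0:
        # Task 0: enorm on inputs_embeds[token, :], with the position-0 mask.
        row = inputs_embeds[token]
        if positions[token] == 0:
            row = [0.0] * len(row)
        enorm_out[token] = rmsnorm_row(row, enorm_w, eps)
    else:
        # Task k + 1: hnorm on previous_hidden_states[token, k, :].
        k = task - 1
        hnorm_out[token][k] = rmsnorm_row(prev_hidden[token][k], hnorm_w, eps)


def launch(inputs_embeds, positions, prev_hidden, enorm_w, hnorm_w, eps, hc_mult):
    num_tokens = len(inputs_embeds)
    enorm_out = [None] * num_tokens
    hnorm_out = [[None] * hc_mult for _ in range(num_tokens)]
    # Every program writes a disjoint output row, so the visiting order is irrelevant.
    for token in range(num_tokens):
        for task in range(hc_mult + 1):
            fused_program(token, task, inputs_embeds, positions, prev_hidden,
                          enorm_w, hnorm_w, eps, enorm_out, hnorm_out)
    return enorm_out, hnorm_out


enorm_out, hnorm_out = launch(
    inputs_embeds=[[1.0, 2.0], [3.0, 4.0]],
    positions=[0, 5],
    prev_hidden=[[[1.0, 1.0], [2.0, -2.0]], [[0.5, 0.0], [3.0, 4.0]]],
    enorm_w=[1.0, 1.0],
    hnorm_w=[2.0, 2.0],
    eps=1e-6,
    hc_mult=2,
)
```

Because no program reads another program's output, the GPU is free to run all T · (hc_mult + 1) row normalizations in parallel from one launch.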

fused_mtp_input_rmsnorm

fused_mtp_input_rmsnorm(
    inputs_embeds: Tensor,
    positions: Tensor,
    previous_hidden_states: Tensor,
    enorm_weight: Tensor,
    hnorm_weight: Tensor,
    eps: float,
    hc_mult: int,
) -> tuple[Tensor, Tensor]

Returns (enorm_out, hnorm_out).

enorm_out has the same shape as inputs_embeds (2D, [T, H]). hnorm_out has the same shape as previous_hidden_states (3D, [T, hc_mult, H]). previous_hidden_states must already be reshaped to 3D.
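Because the function asserts previous_hidden_states.ndim == 3 and shape[1] == hc_mult, the caller performs the .view(-1, hc_mult, H) step from the eager sequence itself. A minimal plain-Python sketch of that reshape, assuming the caller holds the states as a row-major [T, hc_mult * H] buffer (`view_3d` is a hypothetical name):

```python
def view_3d(flat_rows, hc_mult, hidden):
    """Split each [hc_mult * hidden] row into hc_mult rows of length hidden (row-major)."""
    out = []
    for row in flat_rows:
        if len(row) != hc_mult * hidden:
            raise ValueError(
                f"expected {hc_mult * hidden} values per token, got {len(row)}"
            )
        out.append([row[k * hidden:(k + 1) * hidden] for k in range(hc_mult)])
    return out


prev = view_3d([[1, 2, 3, 4, 5, 6]], hc_mult=3, hidden=2)
# prev == [[[1, 2], [3, 4], [5, 6]]]
```

In torch, .view on a contiguous tensor returns a contiguous result, so the reshaped tensor also satisfies the is_contiguous() assertion.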

Source code in vllm/models/deepseek_v4/common/ops/fused_mtp_input_rmsnorm.py
def fused_mtp_input_rmsnorm(
    inputs_embeds: torch.Tensor,
    positions: torch.Tensor,
    previous_hidden_states: torch.Tensor,
    enorm_weight: torch.Tensor,
    hnorm_weight: torch.Tensor,
    eps: float,
    hc_mult: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Returns (enorm_out, hnorm_out).

    enorm_out has the same shape as inputs_embeds (2D, [T, H]).
    hnorm_out has the same shape as previous_hidden_states (3D, [T, hc_mult, H]).
    previous_hidden_states must already be reshaped to 3D.
    """
    assert inputs_embeds.ndim == 2
    assert previous_hidden_states.ndim == 3
    assert previous_hidden_states.shape[1] == hc_mult
    assert inputs_embeds.shape[0] == previous_hidden_states.shape[0], (
        "token dim mismatch"
    )
    assert (
        inputs_embeds.shape[1]
        == previous_hidden_states.shape[2]
        == enorm_weight.shape[0]
        == hnorm_weight.shape[0]
    )
    assert inputs_embeds.is_contiguous() and previous_hidden_states.is_contiguous()
    assert enorm_weight.is_contiguous() and hnorm_weight.is_contiguous()

    num_tokens, hidden = inputs_embeds.shape
    enorm_out = torch.empty_like(inputs_embeds)
    hnorm_out = torch.empty_like(previous_hidden_states)
    if num_tokens == 0:
        return enorm_out, hnorm_out

    block_size = triton.next_power_of_2(hidden)
    _fused_mtp_input_rmsnorm_kernel[(num_tokens, hc_mult + 1)](
        inputs_embeds,
        positions,
        previous_hidden_states,
        enorm_weight,
        hnorm_weight,
        enorm_out,
        hnorm_out,
        eps,
        HIDDEN=hidden,
        HC_MULT=hc_mult,
        BLOCK_SIZE=block_size,
    )
    return enorm_out, hnorm_out
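The launch sets BLOCK_SIZE = triton.next_power_of_2(hidden), which suggests each program holds a whole row in one block; Triton's tl.arange requires a power-of-two extent, so the block is padded past HIDDEN. The kernel body is not shown here, so the sketch below assumes the usual handling: out-of-range lanes are masked (load 0.0, never stored) and the mean of squares divides by HIDDEN rather than BLOCK_SIZE. `next_power_of_2` mimics Triton's helper and `padded_rmsnorm_row` is hypothetical.

```python
import math


def next_power_of_2(n):
    """Smallest power of two >= n for n >= 1 (the contract of triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()


def padded_rmsnorm_row(x, weight, eps, block_size):
    """Emulate one program normalizing a HIDDEN-long row in a BLOCK_SIZE-wide block."""
    hidden = len(x)
    assert block_size >= hidden
    # Lanes at offsets >= HIDDEN are masked off: they load 0.0 and are never stored.
    lanes = [x[i] if i < hidden else 0.0 for i in range(block_size)]
    # Divide by HIDDEN, not BLOCK_SIZE, so the padding does not dilute the mean.
    mean_sq = sum(v * v for v in lanes) / hidden
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [lanes[i] * inv_rms * weight[i] for i in range(hidden)]


hidden = 5
block = next_power_of_2(hidden)  # 8: three padded lanes
row = [1.0, -2.0, 0.5, 3.0, -1.5]
w = [1.0, 0.5, 2.0, 1.0, -1.0]
y = padded_rmsnorm_row(row, w, eps=1e-6, block_size=block)
```

The num_tokens == 0 early return in the wrapper likewise avoids launching a grid with an empty first dimension.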

mtp_shared_head_rmsnorm

mtp_shared_head_rmsnorm(
    hidden_states: Tensor, weight: Tensor, eps: float
) -> Tensor

RMSNorm for MTP's SharedHead.norm, on (T, H) bf16 input.

Uses the same _rmsnorm_row body as fused_mtp_input_rmsnorm so the MTP draft path runs one consistent RMSNorm implementation end to end.
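The reuse can be illustrated with plain-Python stand-ins: one row body called by both entry points. Only the name `_rmsnorm_row` comes from the source; its body here is an assumed standard RMSNorm, and `input_enorm` and `shared_head_norm` are hypothetical names for the task-0 path of the fused kernel and the (num_tokens,) shared-head launch.

```python
import math


def _rmsnorm_row(x, weight, eps):
    # Single row body shared by both entry points (plain-Python stand-in).
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def input_enorm(inputs_embeds, positions, weight, eps):
    # Task-0 path of the fused kernel: mask position-0 rows, then the shared body.
    return [
        _rmsnorm_row([0.0] * len(row) if pos == 0 else row, weight, eps)
        for row, pos in zip(inputs_embeds, positions)
    ]


def shared_head_norm(hidden_states, weight, eps):
    # One program per token, as in the (num_tokens,) grid: shared body, no mask.
    return [_rmsnorm_row(row, weight, eps) for row in hidden_states]


rows = [[0.3, -1.2, 2.5], [4.0, 0.0, -0.5]]
w = [1.0, 2.0, 0.5]
a = input_enorm(rows, positions=[3, 9], weight=w, eps=1e-6)
b = shared_head_norm(rows, weight=w, eps=1e-6)
```

In this sketch a == b holds exactly for unmasked rows, because both paths execute the same operations in the same order; routing every norm through one body is what rules out two subtly different RMSNorm numerics within the draft path.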

Source code in vllm/models/deepseek_v4/common/ops/fused_mtp_input_rmsnorm.py
def mtp_shared_head_rmsnorm(
    hidden_states: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
) -> torch.Tensor:
    """RMSNorm for MTP's SharedHead.norm, on (T, H) bf16 input.

    Uses the same ``_rmsnorm_row`` body as ``fused_mtp_input_rmsnorm`` so the
    MTP draft path runs one consistent RMSNorm implementation end to end.
    """
    assert hidden_states.ndim == 2
    assert hidden_states.is_contiguous()
    assert weight.is_contiguous()
    num_tokens, hidden = hidden_states.shape
    out = torch.empty_like(hidden_states)
    if num_tokens == 0:
        return out
    block_size = triton.next_power_of_2(hidden)
    _mtp_shared_head_rmsnorm_kernel[(num_tokens,)](
        hidden_states,
        weight,
        out,
        eps,
        HIDDEN=hidden,
        BLOCK_SIZE=block_size,
    )
    return out