@PluggableLayer.register("deepseek_v4_multi_head_latent_attention")
class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
"""Pluggable MLA layer which allows OOT backends to add
custom implementations of the outer MLA layer (including rope & o_proj).
Note that currently oot platforms can still use CustomOp.register_oot to
replace MLA layer entirely, although we use PluggableLayer to register
this layer now.
This class takes positions and hidden_states as input.
The input tensors can either contain prefill tokens or decode tokens.
The class does the following:
1. MLA Preprocess.
2. Perform multi-head attention to prefill tokens and
multi-query attention to decode tokens separately.
3. Return the output tensor.
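
    A rough call sketch (construction and module wiring elided; the exact
    setup lives in the DeepseekV4 model definition):

        attn = DeepseekV4MultiHeadLatentAttentionWrapper(..., mla_modules=modules)
        output = attn(positions, hidden_states)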
"""
# --8<-- [end:multi_head_latent_attention]
def __init__(
self,
hidden_size: int,
num_heads: int,
head_dim: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
o_lora_rank: int | None,
mla_modules: DeepseekV4MLAModules,
window_size: int,
compress_ratio: int | None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.n_local_heads = num_heads
self.head_dim = head_dim
self.scale = scale
# FlashMLA sparse kernel only supports 64 or 128 heads; pad up to the
# next supported size. Must match DeepseekV4MLAAttention.padded_heads.
if num_heads <= 64:
self.padded_heads = 64
elif num_heads <= 128:
self.padded_heads = 128
else:
raise ValueError(
f"DeepseekV4 attention does not support {num_heads} heads "
"(must be <= 128)."
)
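        # Illustrative sizes (not from the config): 48 local heads pad up to
        # 64; 96 local heads pad up to 128.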
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.window_size = window_size
self.compress_ratio = compress_ratio if compress_ratio is not None else 1
self.prefix = prefix
# Extract config from vllm_config
config = mla_modules.vllm_config.model_config.hf_config
tp_size = get_tensor_model_parallel_world_size()
# DeepseekV4-specific attributes (num_heads is already TP-adjusted)
self.eps = config.rms_norm_eps
self.rope_head_dim = config.qk_rope_head_dim
self.nope_head_dim = head_dim - self.rope_head_dim
self.n_local_groups = config.o_groups // tp_size
self.o_lora_rank = config.o_lora_rank
# Store projection modules
self.fused_wqa_wkv = mla_modules.fused_wqa_wkv
self.q_norm = mla_modules.q_norm
self.wq_b = mla_modules.wq_b
self.kv_norm = mla_modules.kv_norm
self.wo_a = mla_modules.wo_a
self._wo_a_act_quant = QuantFP8(
static=False,
group_shape=GroupShape(1, 128),
use_ue8m0=True,
)
        # Bypass the packed-for-deepgemm path: we need FP32 scales (not packed
        # INT32) so that fp8_einsum can handle the layout transform internally.
self._wo_a_act_quant.use_deep_gemm_supported = False
self.wo_b = mla_modules.wo_b
# Pick fp8_einsum recipe based on GPU arch:
# SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128
# SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1
from vllm.platforms import current_platform
cap = current_platform.get_device_capability()
assert cap is not None, "DeepseekV4 attention requires a CUDA device"
self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
self._tma_aligned_scales = cap.major >= 10
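        # For example, an H100 (compute capability 9.0) selects recipe
        # (1, 128, 128) with non-TMA-aligned scales, while a B200 (10.0)
        # selects (1, 1, 128) with TMA-aligned scales; device names are
        # illustrative, only cap.major is checked.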
self.rotary_emb = mla_modules.rotary_emb
self.indexer_rotary_emb = mla_modules.indexer_rotary_emb
self.topk_indices_buffer = mla_modules.topk_indices_buffer
self.indexer = mla_modules.indexer
# Per-head RMS normalization for Q (no learnable weights)
self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False)
# TODO(yifan): currently hardcoded for FP8 sparse, make it more generic
        head_bytes = (
            self.nope_head_dim  # FP8 NoPE elems, 1 B each (e.g. 448 B)
            + self.rope_head_dim * 2  # BF16 RoPE elems, 2 B each (e.g. 128 B)
            + self.nope_head_dim // 64  # one 1 B UE8M0 scale per 64 NoPE elems (e.g. 7 B)
            + 1  # 1 B pad
        )
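        # Worked example with the sizes implied by the comments above
        # (head_dim=512, rope=64, nope=448): 448 + 128 + 7 + 1 = 584 bytes
        # per head per cached token.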
self.aux_stream = mla_modules.aux_stream
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
assert cache_config is not None, "DeepseekV4 attention requires cache_config"
self.swa_cache_layer = DeepseekV4SWACache(
head_dim=self.head_dim,
window_size=self.window_size,
dtype=torch.uint8,
prefix=f"{prefix}.swa_cache",
cache_config=cache_config,
)
self.mla_attn = DeepseekV4MLAAttention(
num_heads=self.n_local_heads,
head_dim=self.head_dim,
scale=self.scale,
qk_nope_head_dim=self.nope_head_dim,
qk_rope_head_dim=self.rope_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
compress_ratio=self.compress_ratio,
window_size=self.window_size,
head_bytes=head_bytes,
swa_cache_layer=self.swa_cache_layer,
attn_sink=mla_modules.attn_sink, # already padded with -inf
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
indexer=self.indexer,
topk_indices_buffer=self.topk_indices_buffer,
)
# Register this layer in the compilation config's static forward context
# This allows the custom op to retrieve the layer during execution
compilation_config = mla_modules.vllm_config.compilation_config
        # HACK: derive a distinct registration name for this wrapper from the
        # prefix so it does not collide with its attention sub-layers.
self.layer_name = prefix + ".deepseek_v4_multi_head_latent_attention"
if self.layer_name in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {self.layer_name}")
compilation_config.static_forward_context[self.layer_name] = self
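        # The torch.ops.vllm.deepseek_v4_attention custom op called in
        # forward() is expected to look this layer up by layer_name in the
        # static forward context and dispatch to attention_impl; the lookup
        # itself lives in the custom-op definition, not here.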
        # Create the compressor for layers with compress_ratio > 1. This must
        # happen after DeepseekV4MLAAttention is constructed so the compressor
        # can reference that layer's KV cache via k_cache_prefix.
self.compressor = None
if self.compress_ratio > 1:
self.compressor = DeepseekCompressor(
vllm_config=mla_modules.vllm_config,
compress_ratio=self.compress_ratio,
hidden_size=self.hidden_size,
head_dim=self.head_dim,
rotate=True,
prefix=f"{prefix}.compressor",
k_cache_prefix=self.mla_attn.prefix,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
qr_kv, _ = self.fused_wqa_wkv(hidden_states)
qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
# Pre-allocate attention output with FlashMLA-padded head count.
# The op writes into `o_padded`; we slice to n_local_heads after.
num_tokens = hidden_states.shape[0]
o_padded = torch.empty(
(num_tokens, self.padded_heads, self.head_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
# Attention (inside custom op for torch.compile boundary)
torch.ops.vllm.deepseek_v4_attention(
hidden_states,
qr,
kv,
positions,
o_padded,
self.layer_name,
)
o = o_padded[:, : self.n_local_heads, :]
# O projection: inverse RoPE + FP8 quant + einsum + wo_b
o_fp8, o_scale = fused_inv_rope_fp8_quant(
o,
positions,
self.rotary_emb.cos_sin_cache,
n_groups=self.n_local_groups,
heads_per_group=self.n_local_heads // self.n_local_groups,
nope_dim=self.nope_head_dim,
rope_dim=self.rope_head_dim,
tma_aligned_scales=self._tma_aligned_scales,
)
wo_a_fp8 = self.wo_a.weight
wo_a_scale = self.wo_a.weight_scale_inv
z = torch.empty(
(num_tokens, self.n_local_groups, self.o_lora_rank),
device=o.device,
dtype=torch.bfloat16,
)
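        # Reference semantics of the fused FP8 einsum below (an unquantized
        # sketch; "o_grouped" is an illustrative name for o after inverse RoPE
        # and reshaping to [tokens, n_local_groups, reduced_dim], and wo_a is
        # assumed laid out as [n_local_groups, o_lora_rank, reduced_dim]):
        #     z = torch.einsum("bhr,hdr->bhd", o_grouped, wo_a)
        # The kernel additionally applies the o_scale / wo_a_scale block scales
        # while accumulating.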
torch.ops.vllm.deepseek_v4_fp8_einsum(
o_fp8,
o_scale,
wo_a_fp8,
wo_a_scale,
z,
"bhr,hdr->bhd",
list(self._einsum_recipe),
)
return self.wo_b(z.flatten(1))
def attention_impl(
self,
hidden_states: torch.Tensor,
qr: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
out: torch.Tensor, # [num_tokens, padded_heads, head_dim], written in place
) -> None:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
qr, kv = fused_q_kv_rmsnorm(
qr,
kv,
self.q_norm.weight.data,
self.kv_norm.weight.data,
self.eps,
)
q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
# Overlap kv_insert with whichever of indexer/compressor is present.
# Indexer implies compressor; when both exist, compressor rides on the
# aux stream alongside kv_insert so the heavy indexer owns default.
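        # Stream assignment implied by the policy above:
        #   indexer present:  default stream -> indexer,
        #                     aux stream     -> kv_insert + compressor
        #   compressor only:  default stream -> compressor,
        #                     aux stream     -> kv_insert
        #   neither:          current stream -> kv_insert (no overlap)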
if self.indexer is not None:
indexer = self.indexer
# Local ref so the closure keeps a non-None type for mypy.
assert self.compressor is not None
compressor = self.compressor
def kv_insert_and_compress() -> None:
self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
compressor(hidden_states, positions, self.rotary_emb)
maybe_execute_in_parallel(
lambda: indexer(hidden_states, qr, positions, self.indexer_rotary_emb),
kv_insert_and_compress,
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
elif self.compressor is not None:
# Compressor on default, kv_insert on aux.
compressor = self.compressor
maybe_execute_in_parallel(
lambda: compressor(hidden_states, positions, self.rotary_emb),
lambda: self._fused_qnorm_rope_kv_insert(
q, kv, positions, attn_metadata
),
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
else:
# SWA-only layer: no compressor, no overlap.
self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
# Handle dummy run (no metadata).
if not isinstance(attn_metadata, dict):
            # Reserve _forward_prefill's bf16-gather workspace now: the dummy
            # run returns before mla_attn executes, so without this the shared
            # workspace would be locked in at a size smaller than what the
            # real prefill path needs.
sub = self.mla_attn
swa_only = sub.compress_ratio <= 1
N = (
0
if swa_only
else (sub.max_model_len + sub.compress_ratio - 1) // sub.compress_ratio
)
M = N + sub.window_size + sub.max_num_batched_tokens
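            # Worked example with hypothetical config values
            # (max_model_len=131072, compress_ratio=4, window_size=4096,
            # max_num_batched_tokens=8192): N = 32768 and M = 45056.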
current_workspace_manager().get_simultaneous(
((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
)
out.zero_()
return
# Pad q to FlashMLA-required head count (64 or 128)
if self.n_local_heads < self.padded_heads:
pad_size = self.padded_heads - self.n_local_heads
q = F.pad(q, (0, 0, 0, pad_size), value=0.0)
# MLA attention writes into the pre-allocated `out` buffer
# ([num_tokens, padded_heads, head_dim]).
self.mla_attn(q, kv, positions, output=out)
def _fused_qnorm_rope_kv_insert(
self,
q: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
attn_metadata: (
dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None
),
) -> None:
if not isinstance(attn_metadata, dict):
return
swa_metadata = cast(
"DeepseekSparseSWAMetadata | None",
attn_metadata.get(self.swa_cache_layer.prefix),
)
assert swa_metadata is not None
swa_kv_cache = self.swa_cache_layer.kv_cache
swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1)
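        # Flatten all trailing dims: the fused kernel addresses the cache
        # through slot_mapping (plus block_size) rather than this structured
        # view.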
# Horizontally fused:
# Q side: q_head_norm (per-head RMSNorm, no weight) + GPT-J RoPE
# KV side: GPT-J RoPE + UE8M0 FP8 quant + paged cache insert
# kv is unchanged; mla_attn reads kv solely via swa_kv_cache.
torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
q,
kv,
swa_kv_cache_2d,
swa_metadata.slot_mapping,
positions.to(torch.int64),
self.rotary_emb.cos_sin_cache,
self.eps,
swa_metadata.block_size,
)