
vllm.v1.attention.backends.mla.common

MLA Common Components

This file implements common components for MLA implementations.

First we define:

Sq   as Q sequence length
Skv  as KV sequence length

MLA has two possible ways of computing: a data-movement friendly approach and a compute friendly approach. We generally want to use the compute friendly approach for "prefill" (i.e. when the ratio Sq / Skv is near 1) and the data-movement friendly approach for "decode" (i.e. when the ratio Sq / Skv is small, since Sq is typically 1 while Skv is large).

NOTE: which regime a request falls into is currently determined by whether the scheduler labels it as prefill or decode, but this is something we should probably tune.

Main references: the DeepSeek-V2 paper (https://arxiv.org/abs/2405.04434) and the FlashInfer implementation (https://github.com/flashinfer-ai/flashinfer/pull/551).

Deepseek's MLA attention works the following way:

* Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the memory friendly approach) the attention "simulates" multi-head attention, while the compute is similar to multi-query attention.

Below is an example of both paths, assuming batch size = 1.

Further Definitions:

C    Context length, Skv - Sq
H    hidden size
N    number of attention heads
Lq   latent dimension for Q              1536 in DSV3
Lkv  latent dimension for K/V            512 in DSV3
P    nope dimension, no rope.            128 in DSV3
R    rope dimension, goes through rope.  64 in DSV3
V    V head dim.                         128 in DSV3

Vector/Matrix Definitions

h_t         hidden states (input to attention)  shape [Sq, H]
q_c         latent/compressed Q                 shape [Sq, Lq]
q_nope      uncompressed Q (no-rope)            shape [Sq, N, P]
q_pe        uncompressed Q (rope)               shape [Sq, N, R]
kv_c        latent/compressed KV                shape [Skv, Lkv]
k_pe        decoupled k position embeddings     shape [Skv, R]
new_kv_c    new kv_c from current iter          shape [Sq, Lkv]
new_k_pe    new k_pe from current iter          shape [Sq, R]
cache_kv_c  cached kv_c from previous iters     shape [C, Lkv]
cache_k_pe  cached k_pe from previous iters     shape [C, R]
W_DQ        project h_t to q_c                  shape [H, Lq]
W_UQ        project q_c to q_nope               shape [Lq, N * P]
W_QR        project q_c to q_pe                 shape [Lq, N * R]
W_DKV       project h_t to kv_c                 shape [H, Lkv]
W_UK        project kv_c to k_nope              shape [Lkv, N, P]
W_KR        project h_t to k_pe                 shape [H, R]
W_UV        project kv_c to v                   shape [Lkv, N, V]
W_O         project v to h_t                    shape [N * V, H]

Compute Friendly Approach (i.e. "_forward_prefill"):

q_c      = h_t @ W_DQ
q_nope   = (q_c @ W_UQ).view(Sq, N, P)
q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c     = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe     = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope   = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v        = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)

// MHA with QK headdim = P + R
//          V headdim  = V
// spda_o shape [Sq, N, V]
spda_o = scaled_dot_product_attention(
    torch.cat([q_nope, q_pe], dim=-1),
    torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
    v
)
return spda_o @ W_O

In the actual code:

kv_b_proj is [W_UK; W_UV] concatenated per head
q_b_proj  is [W_UQ; W_QR] concatenated per head
out_proj  is W_O
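
As a concrete illustration of this layout, here is a standalone sketch (DSV3-like dimensions and a random weight are assumptions) of how kv_b_proj's weight splits into per-head W_UK and W_UV, mirroring what process_weights_after_loading does further down in this file.

import torch

# Assumed DSV3-like dimensions, for illustration only
Lkv, N, P, V = 512, 16, 128, 128

# kv_b_proj maps kv_c [*, Lkv] -> [*, N * (P + V)]; viewed as (Lkv, N, P + V),
# its weight holds W_UK and W_UV concatenated along the last dim per head.
kv_b_proj_weight = torch.randn(Lkv, N * (P + V)).view(Lkv, N, P + V)

W_UK, W_UV = kv_b_proj_weight.split([P, V], dim=-1)  # (Lkv, N, P), (Lkv, N, V)

# Layouts used by the bmm-based projections later in the file:
W_UV_bmm = W_UV.transpose(0, 1)    # (N, Lkv, V), up-projects attention output
W_UK_T   = W_UK.permute(1, 2, 0)   # (N, P, Lkv), absorbs W_UK into the query

assert W_UV_bmm.shape == (N, Lkv, V) and W_UK_T.shape == (N, P, Lkv)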

Data-Movement Friendly Approach (i.e. "_forward_decode"):

Runtime:

q_c      = h_t @ W_DQ
q_nope   = (q_c @ W_UQ).view(-1, N, P)
ql_nope  = einsum("snh,lnh->snl", q_nope, W_UK)
q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c     = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe     = torch.cat([new_k_pe, cache_k_pe], dim=0)

// MQA with QK headdim = Lkv + R
//          V headdim  = Lkv
// spda_o shape [Sq, N, Lkv]
// NOTE: this is less compute-friendly since Lkv > P
//       but is more data-movement friendly since it's MQA vs MHA
spda_o = scaled_dot_product_attention(
    torch.cat([ql_nope, q_pe], dim=-1),
    torch.cat([kv_c, k_pe], dim=-1),
    kv_c
)

o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ W_O
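
To see that the two paths agree numerically, here is a self-contained toy comparison (small assumed dimensions, non-causal, plain einsum/softmax rather than a fused kernel) of the MHA-style path against the latent MQA-style path that absorbs W_UK into the query and applies W_UV after attention:

import torch

torch.manual_seed(0)
# Toy sizes (assumptions for illustration)
Sq, Skv, N, P, R, Lkv, V = 2, 5, 4, 16, 8, 32, 16

q_nope = torch.randn(Sq, N, P)
q_pe   = torch.randn(Sq, N, R)
kv_c   = torch.randn(Skv, Lkv)
k_pe   = torch.randn(Skv, R)
W_UK   = torch.randn(Lkv, N, P)
W_UV   = torch.randn(Lkv, N, V)
scale  = (P + R) ** -0.5   # both paths must use the same softmax scale

# Compute friendly (MHA-style): materialize per-head k_nope and v
k_nope = torch.einsum("sl,lnp->snp", kv_c, W_UK)               # [Skv, N, P]
v      = torch.einsum("sl,lnv->snv", kv_c, W_UV)               # [Skv, N, V]
scores = (torch.einsum("qnp,knp->nqk", q_nope, k_nope)
          + torch.einsum("qnr,kr->nqk", q_pe, k_pe)) * scale   # [N, Sq, Skv]
o_mha  = torch.einsum("nqk,knv->qnv", scores.softmax(-1), v)   # [Sq, N, V]

# Data-movement friendly (MQA-style): absorb W_UK into Q, attend over kv_c
ql_nope = torch.einsum("qnp,lnp->qnl", q_nope, W_UK)           # [Sq, N, Lkv]
scores2 = (torch.einsum("qnl,kl->nqk", ql_nope, kv_c)
           + torch.einsum("qnr,kr->nqk", q_pe, k_pe)) * scale
o_lat   = torch.einsum("nqk,kl->qnl", scores2.softmax(-1), kv_c)
o_mqa   = torch.einsum("qnl,lnv->qnv", o_lat, W_UV)            # [Sq, N, V]

assert torch.allclose(o_mha, o_mqa, atol=1e-4)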

Chunked Prefill

For chunked prefill we want to use the compute friendly algorithm. We assume a sufficiently large Sq / Skv ratio; in the future we may want to switch to the data-movement friendly approach if the chunk (i.e. Sq) is small.

However, the compute-friendly approach can potentially run out of memory if Skv is large due to: k_nope = (kv_c @ W_UK).view(Skv, N, P)
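
A back-of-the-envelope calculation (DSV3-like dimensions and fp16 are assumptions) makes the blow-up concrete:

# Rough fp16 memory for materializing k_nope and v over a 64K-token context
# (DSV3-like dimensions assumed for illustration)
Skv, N, P, V, bytes_per_el = 64 * 1024, 128, 128, 128, 2

k_nope_bytes = Skv * N * P * bytes_per_el
v_bytes      = Skv * N * V * bytes_per_el
print(f"k_nope: {k_nope_bytes / 2**30:.1f} GiB, v: {v_bytes / 2**30:.1f} GiB")
# -> 2.0 GiB each, versus ~72 MiB for the latent cache (Skv * (Lkv + R) * 2 bytes)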

To mitigate this, we chunk the computation of attention with respect to the current context (i.e. cache_kv_c and cache_k_pe) so that we can use a fixed workspace size.

The chunked prefill approach is as follows:

MCC Max chunk of context to process per iter, computed dynamically, used to bound the memory usage

q_c        = h_t @ W_DQ
q_nope     = (q_c @ W_UQ).view(Sq, N, P)
q_pe       = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c   = h_t @ W_DKV
new_k_pe   = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
new_v      = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)

// MHA between queries and new KV
//   with QK headdim = P + R
//        V headdim  = V
// curr_o   shape [Sq, N, V]
// curr_lse shape [N, Sq], this is just the order FA returns
curr_o, curr_lse = scaled_dot_product_attention(
    torch.cat([q_nope, q_pe], dim=-1),
    torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
    new_v,
    causal=True,
    return_softmax_lse=True
)

// Compute attention with the already existing context
for chunk_idx in range(cdiv(C, MCC)):
    chunk_start = chunk_idx * MCC
    chunk_end   = min(chunk_start + MCC, C)
    Sc = chunk_end - chunk_start
    cache_kv_c_chunk   = cache_kv_c[chunk_start:chunk_end]
    cache_k_pe_chunk   = cache_k_pe[chunk_start:chunk_end]
    cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
    cache_v_chunk      = (cache_kv_c_chunk @ W_UV).view(-1, N, V)

    chunk_o, chunk_lse = scaled_dot_product_attention(
        torch.cat([q_nope, q_pe], dim=-1),
        torch.cat([cache_k_nope_chunk,
                   cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
                   dim=-1),
        cache_v_chunk,
        causal=False,
        return_softmax_lse=True
    )

    curr_o, curr_lse = merge_attn_states(
        suffix_output=curr_o,
        suffix_lse=curr_lse,
        prefix_output=chunk_o,
        prefix_lse=chunk_lse,
    )

return curr_o @ W_O
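
The merge step relies on the log-sum-exp returned alongside each partial result. A minimal reference version of that merge rule (an illustrative sketch assuming the [N, Sq] lse layout above, not vLLM's fused merge_attn_states kernel) is:

import torch

def merge_attn_states_ref(prefix_output, prefix_lse, suffix_output, suffix_lse):
    """Merge two partial softmax-attention results over disjoint key sets.

    prefix_output / suffix_output: [Sq, N, V] partial outputs
    prefix_lse / suffix_lse:       [N, Sq] log-sum-exp of the partial scores
    """
    max_lse = torch.maximum(prefix_lse, suffix_lse)
    p = torch.exp(prefix_lse - max_lse)      # rescaled softmax denominators
    s = torch.exp(suffix_lse - max_lse)

    # Per-(head, query) weights, broadcast to [Sq, N, 1]
    w_prefix = (p / (p + s)).transpose(0, 1).unsqueeze(-1)
    w_suffix = (s / (p + s)).transpose(0, 1).unsqueeze(-1)

    out = w_prefix * prefix_output + w_suffix * suffix_output
    lse = max_lse + torch.log(p + s)         # merged log-sum-exp, [N, Sq]
    return out, lse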

D module-attribute

D = TypeVar('D', bound=MLACommonDecodeMetadata)

M module-attribute

M = TypeVar('M', bound=MLACommonMetadata)

is_vllm_fa module-attribute

is_vllm_fa = True

logger module-attribute

logger = init_logger(__name__)

MLACommonBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/mla/common.py
class MLACommonBackend(AttentionBackend):

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "TRITON_MLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return MLACommonMetadata

    @staticmethod
    def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
        return MLACommonMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> tuple[int, ...]:
        return (num_blocks, block_size, head_size)

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [576]
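
Because MLA caches a single 576-wide latent entry per token (kv_lora_rank 512 + qk_rope_head_dim 64, the only supported head size above), the cache shape carries no per-head dimension. A quick illustration with hypothetical paging numbers:

from vllm.v1.attention.backends.mla.common import MLACommonBackend

# Hypothetical paged-cache configuration, just to show the resulting shape
shape = MLACommonBackend.get_kv_cache_shape(
    num_blocks=1024,
    block_size=16,
    num_kv_heads=1,      # assumed to be 1 for MLA
    head_size=512 + 64,  # kv_lora_rank + qk_rope_head_dim = 576
)
assert shape == (1024, 16, 576)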

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = True

get_builder_cls staticmethod

get_builder_cls() -> type[MLACommonMetadataBuilder]
Source code in vllm/v1/attention/backends/mla/common.py
@staticmethod
def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
    return MLACommonMetadataBuilder

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/mla/common.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,  # assumed to be 1 for MLA
    head_size: int,
) -> tuple[int, ...]:
    return (num_blocks, block_size, head_size)

get_metadata_cls staticmethod

get_metadata_cls() -> type[AttentionMetadata]
Source code in vllm/v1/attention/backends/mla/common.py
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
    return MLACommonMetadata

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/mla/common.py
@staticmethod
def get_name() -> str:
    return "TRITON_MLA_VLLM_V1"

get_supported_head_sizes staticmethod

get_supported_head_sizes() -> list[int]
Source code in vllm/v1/attention/backends/mla/common.py
@staticmethod
def get_supported_head_sizes() -> list[int]:
    return [576]

MLACommonDecodeMetadata dataclass

Source code in vllm/v1/attention/backends/mla/common.py
@dataclass
class MLACommonDecodeMetadata:
    block_table: torch.Tensor
    seq_lens: torch.Tensor

block_table instance-attribute

block_table: Tensor

seq_lens instance-attribute

seq_lens: Tensor

__init__

__init__(block_table: Tensor, seq_lens: Tensor) -> None
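
Constructing it only needs the per-request block table and sequence lengths; the tensors below are dummy values for illustration:

import torch

from vllm.v1.attention.backends.mla.common import MLACommonDecodeMetadata

# Two decode requests, each with room for 4 KV-cache blocks (dummy values)
decode_meta = MLACommonDecodeMetadata(
    block_table=torch.zeros(2, 4, dtype=torch.int32),
    seq_lens=torch.tensor([17, 33], dtype=torch.int32),
)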

MLACommonImpl

Bases: MLAAttentionImpl[M], Generic[M]

NOTE: Please read the comment at the top of the file before trying to understand this class

Source code in vllm/v1/attention/backends/mla/common.py
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]],
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: ColumnParallelLinear,
    ) -> None:
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported for MLA")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim
        self.kv_b_proj = kv_b_proj

        # Handle the differences between the flash_attn_varlen from flash_attn
        # and the one from vllm_flash_attn. The former is used on RoCM and the
        # latter has an additional parameter to control FA2 vs FA3
        self.flash_attn_varlen_func = flash_attn_varlen_func
        self.vllm_flash_attn_version = get_flash_attn_version()
        if self.vllm_flash_attn_version is not None:
            self.flash_attn_varlen_func = \
                functools.partial(flash_attn_varlen_func,
                                  fa_version=self.vllm_flash_attn_version)

        # For MLA the v head dim is smaller than qk head dim so we pad out
        # v with 0s to match the qk head dim for attention backends that do
        # not support different headdims
        # We don't need to pad V if we are on a hopper system with FA3
        self._pad_v = self.vllm_flash_attn_version is None or not (
            self.vllm_flash_attn_version == 3
            and current_platform.get_device_capability()[0] == 9)

    def _flash_attn_varlen_diff_headdims(self,
                                         q,
                                         k,
                                         v,
                                         return_softmax_lse=False,
                                         softmax_scale=None,
                                         **kwargs):
        maybe_padded_v = v
        if self._pad_v:
            maybe_padded_v = torch.nn.functional.pad(
                v, [0, q.shape[-1] - v.shape[-1]], value=0)

        if is_vllm_fa:
            kwargs["return_softmax_lse"] = return_softmax_lse
        else:
            # ROCm leverages the upstream flash_attn, which takes a parameter
            # called "return_attn_probs" instead of return_softmax_lse
            kwargs["return_attn_probs"] = return_softmax_lse

        attn_out = self.flash_attn_varlen_func(
            q=q,
            k=k,
            v=maybe_padded_v,
            softmax_scale=softmax_scale,
            **kwargs,
        )

        # Unpack the output if there is multiple results
        lse = None
        if isinstance(attn_out, tuple):
            attn_out, lse = attn_out[0], attn_out[1]

        # Remain consistent with old `flash_attn_varlen_func` where there
        # is only one output tensor if `return_softmax_lse` is False.
        if return_softmax_lse:
            return attn_out, lse
        return attn_out

    def _v_up_proj(self, x):
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
        x = torch.bmm(x, self.W_UV)
        # Convert from (N, B, V) to (B, N * V)
        return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)

    def process_weights_after_loading(self, act_dtype: torch.dtype):

        def get_layer_weight(layer):
            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
            for attr in WEIGHT_NAMES:
                if hasattr(layer, attr):
                    return getattr(layer, attr)
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute:"
                f" {WEIGHT_NAMES}.")

        def get_and_maybe_dequant_weights(layer: LinearBase):
            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
                # NOTE: This should only be used offline, since it's O(N^3)
                eye = torch.eye(layer.input_size_per_partition,
                                dtype=act_dtype,
                                device=get_layer_weight(layer).device)
                dequant_weights = layer.quant_method.apply(layer,
                                                           eye,
                                                           bias=None)
                del eye
                # standardize to (output, input)
                return dequant_weights.T
            return layer.weight

        # we currently do not have quantized bmm's which are needed for
        # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
        # the bmm's in 16-bit, the extra memory overhead of this is fairly low
        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
                f"{kv_b_proj_weight.shape=}, "
                f"{self.kv_lora_rank=}, "
                f"{self.num_heads=}, "
                f"{self.qk_nope_head_dim=}, "
                f"{self.v_head_dim=}")
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim,
        )

        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # Convert from (L, N, V) to (N, L, V)
        self.W_UV = W_UV.transpose(0, 1)
        # Convert from (L, N, P) to (N, P, L)
        self.W_UK_T = W_UK.permute(1, 2, 0)

    def _compute_prefill_context(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ):
        assert attn_metadata.prefill is not None
        prefill_metadata = attn_metadata.prefill
        assert prefill_metadata.chunked_context is not None

        output = None
        iters = len(prefill_metadata.chunked_context.seq_tot)
        workspace = prefill_metadata.chunked_context.workspace

        for i in range(iters):
            toks = prefill_metadata.chunked_context.seq_tot[i]

            ops.gather_cache(
                src_cache=kv_c_and_k_pe_cache,
                dst=workspace,
                block_table=prefill_metadata.block_table,
                cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
                batch_size=attn_metadata.num_prefills,
                seq_starts=prefill_metadata.chunked_context.starts[i],
            )

            kv_c_normed = workspace[:toks]\
                [..., :self.kv_lora_rank]
            k_pe = workspace[:toks]\
                [..., self.kv_lora_rank:].unsqueeze(1)

            kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = kv_nope\
                .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                          dim=-1)

            attn_output, attn_softmax_lse = \
                self._flash_attn_varlen_diff_headdims(
                q=q,
                k=k,
                v=v,
                cu_seqlens_q=prefill_metadata.query_start_loc,
                cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
                max_seqlen_q=prefill_metadata.max_query_len,
                max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
                softmax_scale=self.scale,
                causal=False,  # Context is unmasked
                return_softmax_lse=True,
            )

            if output is None:
                output = attn_output
                output_lse = attn_softmax_lse
            else:
                output_tmp = torch.empty_like(output)
                output_lse_tmp = torch.empty_like(output_lse)
                merge_attn_states(
                    output=output_tmp,
                    output_lse=output_lse_tmp,
                    prefix_output=output,
                    prefix_lse=output_lse,
                    suffix_output=attn_output,
                    suffix_lse=attn_softmax_lse,
                )
                output = output_tmp
                output_lse = output_lse_tmp

        return output, output_lse

    def _forward_prefill(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ) -> torch.Tensor:
        assert attn_metadata.prefill is not None

        has_context = attn_metadata.prefill.chunked_context is not None
        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope\
            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

        output = self._flash_attn_varlen_diff_headdims(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=attn_metadata.prefill.query_start_loc,
            cu_seqlens_k=attn_metadata.prefill.query_start_loc,
            max_seqlen_q=attn_metadata.prefill.max_query_len,
            max_seqlen_k=attn_metadata.prefill.max_query_len,
            softmax_scale=self.scale,
            causal=True,
            return_softmax_lse=has_context,
        )

        if has_context:
            suffix_output, suffix_lse = output
            context_output, context_lse = self._compute_prefill_context( \
                q, kv_c_and_k_pe_cache, attn_metadata)

            output = torch.empty_like(suffix_output)
            merge_attn_states(
                output=output,
                prefix_output=context_output,
                prefix_lse=context_lse,
                suffix_output=suffix_output,
                suffix_lse=suffix_lse,
            )

        # unpad if necessary
        if self._pad_v:
            output = output[..., :v.shape[-1]]

        return output.flatten(start_dim=-2)

    @abstractmethod
    def _forward_decode(
        self,
        ql_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: M,
    ) -> torch.Tensor:
        raise NotImplementedError

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: M,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for MLACommonImpl")

        if attn_metadata is None:
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs
        output_padded = output
        output = output[:num_actual_toks, ...]
        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        assert attn_metadata.num_decodes is not None and \
            attn_metadata.num_prefills is not None and \
            attn_metadata.num_decode_tokens is not None

        has_decode = attn_metadata.num_decodes > 0
        has_prefill = attn_metadata.num_prefills > 0
        num_decode_tokens = attn_metadata.num_decode_tokens

        decode_q = q[:num_decode_tokens]

        prefill_q = q[num_decode_tokens:]
        prefill_k_pe = k_pe[num_decode_tokens:]
        prefill_k_c_normed = k_c_normed[num_decode_tokens:]

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        if has_prefill:
            output[num_decode_tokens:] = self._forward_prefill(
                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
                attn_metadata)

        if has_decode:
            assert attn_metadata.decode is not None
            decode_q_nope, decode_q_pe = decode_q.split(
                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
            # Convert from (B, N, P) to (N, B, P)
            decode_q_nope = decode_q_nope.transpose(0, 1)
            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
            # Convert from (N, B, L) to (B, N, L)
            decode_ql_nope = decode_ql_nope.transpose(0, 1)

            output[:num_decode_tokens] = self._forward_decode(
                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)

        return output_padded

_pad_v instance-attribute

_pad_v = vllm_flash_attn_version is None or not (
    vllm_flash_attn_version == 3
    and get_device_capability()[0] == 9
)

flash_attn_varlen_func instance-attribute

flash_attn_varlen_func = flash_attn_varlen_func

head_size instance-attribute

head_size = head_size

kv_b_proj instance-attribute

kv_b_proj = kv_b_proj

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

kv_lora_rank instance-attribute

kv_lora_rank = kv_lora_rank

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

q_lora_rank instance-attribute

q_lora_rank = q_lora_rank

qk_head_dim instance-attribute

qk_head_dim = qk_head_dim

qk_nope_head_dim instance-attribute

qk_nope_head_dim = qk_nope_head_dim

qk_rope_head_dim instance-attribute

qk_rope_head_dim = qk_rope_head_dim

scale instance-attribute

scale = float(scale)

v_head_dim instance-attribute

v_head_dim = v_head_dim

vllm_flash_attn_version instance-attribute

vllm_flash_attn_version = get_flash_attn_version()

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[dict[str, Any]],
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    q_lora_rank: Optional[int],
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    qk_head_dim: int,
    v_head_dim: int,
    kv_b_proj: ColumnParallelLinear,
) -> None
Source code in vllm/v1/attention/backends/mla/common.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[dict[str, Any]],
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    # MLA Specific Arguments
    q_lora_rank: Optional[int],
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    qk_head_dim: int,
    v_head_dim: int,
    kv_b_proj: ColumnParallelLinear,
) -> None:
    if kv_sharing_target_layer_name is not None:
        raise NotImplementedError("KV sharing is not supported for MLA")

    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    self.kv_cache_dtype = kv_cache_dtype

    self.q_lora_rank = q_lora_rank
    self.kv_lora_rank = kv_lora_rank
    self.qk_nope_head_dim = qk_nope_head_dim
    self.qk_rope_head_dim = qk_rope_head_dim
    self.qk_head_dim = qk_head_dim
    self.v_head_dim = v_head_dim
    self.kv_b_proj = kv_b_proj

    # Handle the differences between the flash_attn_varlen from flash_attn
    # and the one from vllm_flash_attn. The former is used on RoCM and the
    # latter has an additional parameter to control FA2 vs FA3
    self.flash_attn_varlen_func = flash_attn_varlen_func
    self.vllm_flash_attn_version = get_flash_attn_version()
    if self.vllm_flash_attn_version is not None:
        self.flash_attn_varlen_func = \
            functools.partial(flash_attn_varlen_func,
                              fa_version=self.vllm_flash_attn_version)

    # For MLA the v head dim is smaller than qk head dim so we pad out
    # v with 0s to match the qk head dim for attention backends that do
    # not support different headdims
    # We don't need to pad V if we are on a hopper system with FA3
    self._pad_v = self.vllm_flash_attn_version is None or not (
        self.vllm_flash_attn_version == 3
        and current_platform.get_device_capability()[0] == 9)

_compute_prefill_context

_compute_prefill_context(
    q: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    attn_metadata: MLACommonMetadata,
)
Source code in vllm/v1/attention/backends/mla/common.py
def _compute_prefill_context(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata: MLACommonMetadata,
):
    assert attn_metadata.prefill is not None
    prefill_metadata = attn_metadata.prefill
    assert prefill_metadata.chunked_context is not None

    output = None
    iters = len(prefill_metadata.chunked_context.seq_tot)
    workspace = prefill_metadata.chunked_context.workspace

    for i in range(iters):
        toks = prefill_metadata.chunked_context.seq_tot[i]

        ops.gather_cache(
            src_cache=kv_c_and_k_pe_cache,
            dst=workspace,
            block_table=prefill_metadata.block_table,
            cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
            batch_size=attn_metadata.num_prefills,
            seq_starts=prefill_metadata.chunked_context.starts[i],
        )

        kv_c_normed = workspace[:toks]\
            [..., :self.kv_lora_rank]
        k_pe = workspace[:toks]\
            [..., self.kv_lora_rank:].unsqueeze(1)

        kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope\
            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                      dim=-1)

        attn_output, attn_softmax_lse = \
            self._flash_attn_varlen_diff_headdims(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=prefill_metadata.query_start_loc,
            cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
            max_seqlen_q=prefill_metadata.max_query_len,
            max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
            softmax_scale=self.scale,
            causal=False,  # Context is unmasked
            return_softmax_lse=True,
        )

        if output is None:
            output = attn_output
            output_lse = attn_softmax_lse
        else:
            output_tmp = torch.empty_like(output)
            output_lse_tmp = torch.empty_like(output_lse)
            merge_attn_states(
                output=output_tmp,
                output_lse=output_lse_tmp,
                prefix_output=output,
                prefix_lse=output_lse,
                suffix_output=attn_output,
                suffix_lse=attn_softmax_lse,
            )
            output = output_tmp
            output_lse = output_lse_tmp

    return output, output_lse

_flash_attn_varlen_diff_headdims

_flash_attn_varlen_diff_headdims(
    q,
    k,
    v,
    return_softmax_lse=False,
    softmax_scale=None,
    **kwargs,
)
Source code in vllm/v1/attention/backends/mla/common.py
def _flash_attn_varlen_diff_headdims(self,
                                     q,
                                     k,
                                     v,
                                     return_softmax_lse=False,
                                     softmax_scale=None,
                                     **kwargs):
    maybe_padded_v = v
    if self._pad_v:
        maybe_padded_v = torch.nn.functional.pad(
            v, [0, q.shape[-1] - v.shape[-1]], value=0)

    if is_vllm_fa:
        kwargs["return_softmax_lse"] = return_softmax_lse
    else:
        # ROCm leverages the upstream flash_attn, which takes a parameter
        # called "return_attn_probs" instead of return_softmax_lse
        kwargs["return_attn_probs"] = return_softmax_lse

    attn_out = self.flash_attn_varlen_func(
        q=q,
        k=k,
        v=maybe_padded_v,
        softmax_scale=softmax_scale,
        **kwargs,
    )

    # Unpack the output if there is multiple results
    lse = None
    if isinstance(attn_out, tuple):
        attn_out, lse = attn_out[0], attn_out[1]

    # Remain consistent with old `flash_attn_varlen_func` where there
    # is only one output tensor if `return_softmax_lse` is False.
    if return_softmax_lse:
        return attn_out, lse
    return attn_out
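
The V handling here is a plain zero-pad of the last dimension up to the Q/K head dim, which the caller later slices back off; a small illustrative sketch with toy shapes (not the backend call itself):

import torch

# Toy shapes: qk head dim 192 (128 nope + 64 rope), v head dim 128
q = torch.randn(10, 16, 192)
v = torch.randn(10, 16, 128)

# Zero-pad v's last dim to match q's head dim, as the _pad_v branch does
padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0)
assert padded_v.shape == (10, 16, 192)

# After attention the caller slices the output back to the true v head dim,
# e.g. the `output[..., :v.shape[-1]]` unpad in _forward_prefill.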

_forward_decode abstractmethod

_forward_decode(
    ql_nope: Tensor,
    q_pe: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    attn_metadata: M,
) -> Tensor
Source code in vllm/v1/attention/backends/mla/common.py
@abstractmethod
def _forward_decode(
    self,
    ql_nope: torch.Tensor,
    q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata: M,
) -> torch.Tensor:
    raise NotImplementedError

_forward_prefill

_forward_prefill(
    q: Tensor,
    kv_c_normed: Tensor,
    k_pe: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    attn_metadata: MLACommonMetadata,
) -> Tensor
Source code in vllm/v1/attention/backends/mla/common.py
def _forward_prefill(
    self,
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
    assert attn_metadata.prefill is not None

    has_context = attn_metadata.prefill.chunked_context is not None
    kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
        -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
    k_nope, v = kv_nope\
        .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

    k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

    output = self._flash_attn_varlen_diff_headdims(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=attn_metadata.prefill.query_start_loc,
        cu_seqlens_k=attn_metadata.prefill.query_start_loc,
        max_seqlen_q=attn_metadata.prefill.max_query_len,
        max_seqlen_k=attn_metadata.prefill.max_query_len,
        softmax_scale=self.scale,
        causal=True,
        return_softmax_lse=has_context,
    )

    if has_context:
        suffix_output, suffix_lse = output
        context_output, context_lse = self._compute_prefill_context( \
            q, kv_c_and_k_pe_cache, attn_metadata)

        output = torch.empty_like(suffix_output)
        merge_attn_states(
            output=output,
            prefix_output=context_output,
            prefix_lse=context_lse,
            suffix_output=suffix_output,
            suffix_lse=suffix_lse,
        )

    # unpad if necessary
    if self._pad_v:
        output = output[..., :v.shape[-1]]

    return output.flatten(start_dim=-2)

_v_up_proj

_v_up_proj(x)
Source code in vllm/v1/attention/backends/mla/common.py
def _v_up_proj(self, x):
    # Convert from (B, N, L) to (N, B, L)
    x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
    # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
    x = torch.bmm(x, self.W_UV)
    # Convert from (N, B, V) to (B, N * V)
    return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
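
Shape-wise this is a batched matmul against the per-head W_UV; the standalone restatement below uses toy sizes (assumed dimensions) to show the layout changes:

import torch

B, N, L, V = 3, 16, 512, 128    # toy batch, heads, kv_lora_rank, v_head_dim
x    = torch.randn(B, N, L)     # latent attention output, (B, N, L)
W_UV = torch.randn(N, L, V)     # per-head value up-projection

out = torch.bmm(x.transpose(0, 1), W_UV)    # (N, B, L) x (N, L, V) -> (N, B, V)
out = out.transpose(0, 1).reshape(B, N * V)
assert out.shape == (B, N * V)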

forward

forward(
    layer: AttentionLayer,
    q: Tensor,
    k_c_normed: Tensor,
    k_pe: Tensor,
    kv_cache: Tensor,
    attn_metadata: M,
    output: Optional[Tensor] = None,
    output_scale: Optional[Tensor] = None,
) -> Tensor
Source code in vllm/v1/attention/backends/mla/common.py
def forward(
    self,
    layer: AttentionLayer,
    q: torch.Tensor,
    k_c_normed: torch.Tensor,  # key in unified attn
    k_pe: torch.Tensor,  # value in unified attn
    kv_cache: torch.Tensor,
    attn_metadata: M,
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:

    assert output is not None, "Output tensor must be provided."

    if output_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported"
            " for MLACommonImpl")

    if attn_metadata is None:
        # The zero fill is required when used with DP + EP
        # to ensure all ranks within a DP group compute the
        # same expert outputs.
        return output.fill_(0)

    num_actual_toks = attn_metadata.num_actual_tokens

    # Inputs and outputs may be padded for CUDA graphs
    output_padded = output
    output = output[:num_actual_toks, ...]
    q = q[:num_actual_toks, ...]
    k_c_normed = k_c_normed[:num_actual_toks, ...]
    k_pe = k_pe[:num_actual_toks, ...]

    assert attn_metadata.num_decodes is not None and \
        attn_metadata.num_prefills is not None and \
        attn_metadata.num_decode_tokens is not None

    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    decode_q = q[:num_decode_tokens]

    prefill_q = q[num_decode_tokens:]
    prefill_k_pe = k_pe[num_decode_tokens:]
    prefill_k_c_normed = k_c_normed[num_decode_tokens:]

    # write the latent and rope to kv cache
    if kv_cache.numel() > 0:
        ops.concat_and_cache_mla(
            k_c_normed,
            k_pe.squeeze(1),
            kv_cache,
            attn_metadata.slot_mapping.flatten(),
            kv_cache_dtype=self.kv_cache_dtype,
            scale=layer._k_scale,
        )

    if has_prefill:
        output[num_decode_tokens:] = self._forward_prefill(
            prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
            attn_metadata)

    if has_decode:
        assert attn_metadata.decode is not None
        decode_q_nope, decode_q_pe = decode_q.split(
            [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # Convert from (B, N, P) to (N, B, P)
        decode_q_nope = decode_q_nope.transpose(0, 1)
        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
        # Convert from (N, B, L) to (B, N, L)
        decode_ql_nope = decode_ql_nope.transpose(0, 1)

        output[:num_decode_tokens] = self._forward_decode(
            decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)

    return output_padded

process_weights_after_loading

process_weights_after_loading(act_dtype: dtype)
Source code in vllm/v1/attention/backends/mla/common.py
def process_weights_after_loading(self, act_dtype: torch.dtype):

    def get_layer_weight(layer):
        WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
        for attr in WEIGHT_NAMES:
            if hasattr(layer, attr):
                return getattr(layer, attr)
        raise AttributeError(
            f"Layer '{layer}' has no recognized weight attribute:"
            f" {WEIGHT_NAMES}.")

    def get_and_maybe_dequant_weights(layer: LinearBase):
        if not isinstance(layer.quant_method, UnquantizedLinearMethod):
            # NOTE: This should only be used offline, since it's O(N^3)
            eye = torch.eye(layer.input_size_per_partition,
                            dtype=act_dtype,
                            device=get_layer_weight(layer).device)
            dequant_weights = layer.quant_method.apply(layer,
                                                       eye,
                                                       bias=None)
            del eye
            # standardize to (output, input)
            return dequant_weights.T
        return layer.weight

    # we currently do not have quantized bmm's which are needed for
    # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
    # the bmm's in 16-bit, the extra memory overhead of this is fairly low
    kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
    assert kv_b_proj_weight.shape == (
        self.kv_lora_rank,
        self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
            f"{kv_b_proj_weight.shape=}, "
            f"{self.kv_lora_rank=}, "
            f"{self.num_heads=}, "
            f"{self.qk_nope_head_dim=}, "
            f"{self.v_head_dim=}")
    kv_b_proj_weight = kv_b_proj_weight.view(
        self.kv_lora_rank,
        self.num_heads,
        self.qk_nope_head_dim + self.v_head_dim,
    )

    W_UK, W_UV = kv_b_proj_weight.split(
        [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

    # Convert from (L, N, V) to (N, L, V)
    self.W_UV = W_UV.transpose(0, 1)
    # Convert from (L, N, P) to (N, P, L)
    self.W_UK_T = W_UK.permute(1, 2, 0)
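
The identity-matrix trick in get_and_maybe_dequant_weights works because applying a linear map to the identity returns its (dequantized) weight matrix transposed. Below is a tiny demonstration with a plain nn.Linear standing in for layer.quant_method.apply (an illustrative sketch, not the quantized path):

import torch

# Stand-in for a (quantized) linear layer's forward: y = x @ W.T
layer = torch.nn.Linear(in_features=8, out_features=4, bias=False)

eye = torch.eye(layer.in_features)
dequant_weights = layer(eye)        # shape (in, out), i.e. W.T

# Standardize to (output, input), as the real code does with `.T`
assert torch.allclose(dequant_weights.T, layer.weight)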

MLACommonMetadata dataclass

Bases: Generic[D]

Metadata for MLACommon.

NOTE: Please read the comment at the top of the file before trying to understand this class

Source code in vllm/v1/attention/backends/mla/common.py
@dataclass
class MLACommonMetadata(Generic[D]):
    """Metadata for MLACommon.

    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int

    # The dimension of the attention heads
    head_dim: Optional[int] = None

    decode: Optional[D] = None
    prefill: Optional[MLACommonPrefillMetadata] = None

    def __post_init__(self):
        supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim,",
                f"received {self.head_dim}.")

decode class-attribute instance-attribute

decode: Optional[D] = None

head_dim class-attribute instance-attribute

head_dim: Optional[int] = None

num_actual_tokens instance-attribute

num_actual_tokens: int

num_decode_tokens instance-attribute

num_decode_tokens: int

num_decodes instance-attribute

num_decodes: int

num_prefills instance-attribute

num_prefills: int

prefill class-attribute instance-attribute

query_start_loc instance-attribute

query_start_loc: Tensor

slot_mapping instance-attribute

slot_mapping: Tensor

__init__

__init__(
    num_actual_tokens: int,
    query_start_loc: Tensor,
    slot_mapping: Tensor,
    num_decodes: int,
    num_decode_tokens: int,
    num_prefills: int,
    head_dim: Optional[int] = None,
    decode: Optional[D] = None,
    prefill: Optional[MLACommonPrefillMetadata] = None,
) -> None

__post_init__

__post_init__()
Source code in vllm/v1/attention/backends/mla/common.py
def __post_init__(self):
    supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
    if self.head_dim is not None and self.head_dim \
            not in supported_head_sizes:
        raise ValueError(
            f"Only {supported_head_sizes} are supported for head_dim,",
            f"received {self.head_dim}.")

MLACommonMetadataBuilder

Bases: AttentionMetadataBuilder[M]

NOTE: Please read the comment at the top of the file before trying to understand this class

Source code in vllm/v1/attention/backends/mla/common.py
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    def __init__(self,
                 runner: "GPUModelRunner",
                 kv_cache_spec: AttentionSpec,
                 block_table: BlockTable,
                 metadata_cls: Optional[type[M]] = None):
        self.metadata_cls = metadata_cls \
            if metadata_cls is not None else MLACommonMetadata
        self.runner = runner
        scheduler_config = runner.scheduler_config
        model_config = runner.model_config
        cache_config = runner.cache_config
        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
        self.num_heads = model_config.get_num_attention_heads(
            runner.parallel_config)
        self.mla_dims = get_mla_dims(model_config)
        self.aot_schedule = current_platform.is_cuda()
        self.kv_cache_spec = kv_cache_spec

        # Dont try to access the runner on AMD
        if self.aot_schedule:
            self.page_size = self.kv_cache_spec.block_size

        if self.chunked_prefill_enabled:
            self.chunked_prefill_workspace_size = min(
                # Max sure there is enough for 8 full length request or at least
                # 4 pages of cache per request
                max(
                    8 * model_config.max_model_len, 4 *
                    scheduler_config.max_num_seqs * cache_config.block_size),
                # For long-context models try not to over-allocate limiting
                # kv-cache space, limiting it to 64k tokens,
                # which would result in the workspace being:
                #   2*(576)*(64*1024) = 144mb
                # (assuming 576 MLA head dim, and fp16)
                # which would result in up-projected context being
                #   2*(192*128)*(64*1024) = 3gb
                # (assuming 192 QK head dim, 128 heads, and fp16)
                128 * 1024)
            assert self.chunked_prefill_workspace_size >= \
                scheduler_config.max_num_seqs * cache_config.block_size
            self.chunked_prefill_workspace = torch.empty(
                (self.chunked_prefill_workspace_size,
                 model_config.get_head_size()),
                dtype=model_config.dtype,
                device=runner.device,
            )
        self.block_table = block_table

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        # We now want to reorder the batch so that the "decode" requests are at
        # the front and the "prefill" requests are at the back, using the least
        # amount of swaps possible. (NOTE for now we loosely use "decode" to
        # mean requests
        # where attention is likely memory-bound and "prefill" to mean requests
        # where attention is likely compute-bound, TODO(lucas): figure out a
        # better naming here)
        decodes = []
        prefills = []
        num_decode_tokens = 0
        num_prefill_tokens = 0

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            # for now treat 1 scheduled token as "decode" even if its not,
            # we should update this to something like < 8 in the future but
            # currently the TritonMLA._forward_decode only supports
            # num_tokens = 1
            if num_tokens == 1:
                decodes.append(i)
                num_decode_tokens += num_tokens
            else:
                prefills.append(i)
                num_prefill_tokens += num_tokens

        # We hope that this is fairly minimal since decodes
        # should be around for a number of iterations so hopefully they are
        # relatively stationary (and new request are generally appended to the
        # persistent batch so already should be at the back)
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the  "back"
        # i.e. past where the last decode should be in the reordered batch with
        # prefills from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based on
        # the above loop
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the decode is at the "back" of the batch, i, we can swap it
            # with the prefill closest to the front of the batch
            decode_idx = decodes[num_decodes - i]
            if decode_idx < num_decodes:
                break

            input_batch.swap_states(prefills[i - 1], decode_idx)
            modified_batch = True

        # Save for next `build` call
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this
        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens

        return modified_batch

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens: torch.Tensor):
        return MLACommonDecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens,
        )

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata) -> M:
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with MLA.
        """
        m = common_attn_metadata
        assert m.num_reqs == m.num_actual_tokens, \
            "MLA only supports decode-only full CUDAGraph capture. " \
            "Make sure all cudagraph capture sizes <= max_num_seq."

        m.max_query_len = 1  # decode-only

        # Update state usually set in reorder_batch.
        self._num_decodes = m.num_reqs
        self._num_decode_tokens = m.num_actual_tokens
        self._num_prefills = 0
        self._num_prefill_tokens = 0
        return self.build(0, m)

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata) -> M:
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        assert self._num_decodes + self._num_prefills == num_reqs

        # Note(simon): be careful about the CPU <> GPU memory movement in this
        # function. We should avoid GPU -> CPU sync as much as possible because
        # it blocks on all previous kernels.
        device = self.runner.device
        block_table = self.block_table
        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
        block_table.slot_mapping[:num_actual_tokens].copy_(
            block_table.slot_mapping_cpu[:num_actual_tokens],
            non_blocking=True)
        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
        slot_mapping = block_table.slot_mapping[:num_actual_tokens]

        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens

        prefill_metadata = None
        if self._num_prefills > 0:
            reqs_start = self._num_decodes  # prefill_start

            context_lens_cpu = self.runner.input_batch.\
                num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
            max_context_len_cpu = context_lens_cpu.max().item()
            num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
            prefill_query_start_loc = query_start_loc[
                reqs_start:] - query_start_loc[reqs_start]

            chunked_context_metadata = None
            if self.chunked_prefill_enabled and self._num_prefills > 0 \
                and max_context_len_cpu > 0:
                # NOTE: it is recommend you read the `Chunked Prefill` section
                # in the comment at the top of the file before trying to
                # understand the following code

                # currently we allocate an equal amount of workspace for each
                # prefill in the batch, we could probably use a more advanced
                # algorithm here and allocate more workspace to prefills with
                # longer context lengths
                max_context_chunk = (self.chunked_prefill_workspace_size //
                                     num_prefills_with_context_cpu)

                if self.aot_schedule:
                    # align max_context_chunk to page_size by rounding down,
                    # currently the `gather_cache` kernel cannot handle
                    # `context_chunk_starts` that are not aligned to page_size
                    max_context_chunk = round_down(max_context_chunk,
                                                   self.page_size)

                assert max_context_chunk > 0
                num_chunks = cdiv(max_context_len_cpu, max_context_chunk)

                # if `max_context_chunk = 256`, `num_chunks = 3`, and
                #   `num_prefills_with_context = 4`, create a tensor that looks
                # like
                #  [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
                # Note(simon): this is done in CPU because of downstream's
                # of `to_list`.
                chunk_starts = \
                    torch.arange(num_chunks, dtype=torch.int32) \
                    .unsqueeze(1).expand(-1, self._num_prefills) \
                    * max_context_chunk
                chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                       chunk_starts + max_context_chunk)
                chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

                cu_seq_lens_cpu = torch.zeros(num_chunks,
                                              self._num_prefills + 1,
                                              dtype=torch.int32,
                                              pin_memory=True)
                torch.cumsum(chunk_seq_lens,
                             dim=1,
                             out=cu_seq_lens_cpu[:, 1:],
                             dtype=torch.int32)

                chunked_context_metadata = \
                    MLACommonPrefillMetadata.ChunkedContextMetadata(
                    cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
                    starts=chunk_starts.to(device, non_blocking=True),
                    seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                    max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                    workspace=self.chunked_prefill_workspace,
                )

                assert max(chunked_context_metadata.max_seq_lens) <= \
                    self.chunked_prefill_workspace_size

            prefill_metadata = MLACommonPrefillMetadata(
                block_table=block_table_tensor[reqs_start:, ...],
                query_start_loc=prefill_query_start_loc,
                max_query_len=max_query_len,
                chunked_context=chunked_context_metadata,
            )

        decode_metadata = None
        if self._num_decodes > 0:
            decode_metadata = self._build_decode(
                block_table_tensor=block_table_tensor[:self._num_decodes, ...],
                seq_lens=seq_lens[:self._num_decodes],
            )

        return self.metadata_cls(
            num_actual_tokens=num_actual_tokens,
            query_start_loc=query_start_loc,
            slot_mapping=slot_mapping,
            head_dim=self.runner.model_config.get_head_size(),
            # MLACommonMetadata Chunk prefill specific
            num_decodes=self._num_decodes,
            num_decode_tokens=self._num_decode_tokens,
            num_prefills=self._num_prefills,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )

    def can_run_in_cudagraph(
            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
        return common_attn_metadata.max_query_len == 1
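
The swap strategy in reorder_batch can be illustrated with a small standalone sketch (a toy request list; input_batch.swap_states is replaced by a plain Python swap):

# Toy illustration of the reorder_batch swap strategy: decode requests end up
# at the front, prefills at the back, using as few swaps as possible.
reqs = ["prefill", "decode", "decode", "prefill", "decode"]

decodes = [i for i, r in enumerate(reqs) if r == "decode"]
prefills = [i for i, r in enumerate(reqs) if r == "prefill"]
num_decodes = len(decodes)

for i in range(1, min(num_decodes, len(prefills)) + 1):
    # Walk decodes from the back of the batch; stop once every remaining
    # decode already sits within the first `num_decodes` slots.
    decode_idx = decodes[num_decodes - i]
    if decode_idx < num_decodes:
        break
    # Swap it with the prefill closest to the front
    # (input_batch.swap_states in the real code).
    prefill_idx = prefills[i - 1]
    reqs[prefill_idx], reqs[decode_idx] = reqs[decode_idx], reqs[prefill_idx]

assert reqs == ["decode", "decode", "decode", "prefill", "prefill"]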

aot_schedule instance-attribute

aot_schedule = is_cuda()

block_table instance-attribute

block_table = block_table

chunked_prefill_enabled instance-attribute

chunked_prefill_enabled = chunked_prefill_enabled

chunked_prefill_workspace instance-attribute

chunked_prefill_workspace = empty(
    (chunked_prefill_workspace_size, get_head_size()),
    dtype=dtype,
    device=device,
)

chunked_prefill_workspace_size instance-attribute

chunked_prefill_workspace_size = min(
    max(8 * max_model_len, 4 * max_num_seqs * block_size),
    128 * 1024,
)

kv_cache_spec instance-attribute

kv_cache_spec = kv_cache_spec

metadata_cls instance-attribute

metadata_cls = (
    metadata_cls
    if metadata_cls is not None
    else MLACommonMetadata
)

mla_dims instance-attribute

mla_dims = get_mla_dims(model_config)

num_heads instance-attribute

num_heads = get_num_attention_heads(parallel_config)

page_size instance-attribute

page_size = block_size

runner instance-attribute

runner = runner

__init__

__init__(
    runner: GPUModelRunner,
    kv_cache_spec: AttentionSpec,
    block_table: BlockTable,
    metadata_cls: Optional[type[M]] = None,
)
Source code in vllm/v1/attention/backends/mla/common.py
def __init__(self,
             runner: "GPUModelRunner",
             kv_cache_spec: AttentionSpec,
             block_table: BlockTable,
             metadata_cls: Optional[type[M]] = None):
    self.metadata_cls = metadata_cls \
        if metadata_cls is not None else MLACommonMetadata
    self.runner = runner
    scheduler_config = runner.scheduler_config
    model_config = runner.model_config
    cache_config = runner.cache_config
    self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
    self.num_heads = model_config.get_num_attention_heads(
        runner.parallel_config)
    self.mla_dims = get_mla_dims(model_config)
    self.aot_schedule = current_platform.is_cuda()
    self.kv_cache_spec = kv_cache_spec

    # Dont try to access the runner on AMD
    if self.aot_schedule:
        self.page_size = self.kv_cache_spec.block_size

    if self.chunked_prefill_enabled:
        self.chunked_prefill_workspace_size = min(
            # Make sure there is enough for 8 full-length requests or at
            # least 4 pages of cache per request
            max(
                8 * model_config.max_model_len, 4 *
                scheduler_config.max_num_seqs * cache_config.block_size),
            # For long-context models try not to over-allocate KV-cache
            # space; limit the workspace to 128k tokens,
            # which would result in the workspace being:
            #   2*(576)*(128*1024) = 144mb
            # (assuming 576 MLA head dim, and fp16)
            # and the up-projected context being:
            #   2*(192*128)*(128*1024) = 6gb
            # (assuming 192 QK head dim, 128 heads, and fp16)
            128 * 1024)
        assert self.chunked_prefill_workspace_size >= \
            scheduler_config.max_num_seqs * cache_config.block_size
        self.chunked_prefill_workspace = torch.empty(
            (self.chunked_prefill_workspace_size,
             model_config.get_head_size()),
            dtype=model_config.dtype,
            device=runner.device,
        )
    self.block_table = block_table
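
A minimal numeric sketch of the workspace sizing above, assuming DeepSeek-V3-like dimensions (576 MLA head dim, fp16) and made-up scheduler limits; the bare names below are hypothetical stand-ins for the values normally read from model_config, scheduler_config, and cache_config:

# Hypothetical config values, for illustration only.
max_model_len = 32 * 1024   # stand-in for model_config.max_model_len
max_num_seqs = 256          # stand-in for scheduler_config.max_num_seqs
block_size = 16             # stand-in for cache_config.block_size
mla_head_dim = 576          # kv_c (512) + k_pe (64) in DSV3
bytes_per_elem = 2          # fp16

# Same sizing rule as in __init__: room for 8 full-length requests or at
# least 4 pages of cache per request, capped at 128k tokens.
workspace_tokens = min(
    max(8 * max_model_len, 4 * max_num_seqs * block_size),
    128 * 1024,
)
workspace_bytes = workspace_tokens * mla_head_dim * bytes_per_elem

print(workspace_tokens)          # 131072 (the 128k cap wins here)
print(workspace_bytes / 2**20)   # 144.0 (MiB)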

_build_decode

_build_decode(block_table_tensor: Tensor, seq_lens: Tensor)
Source code in vllm/v1/attention/backends/mla/common.py
def _build_decode(self, block_table_tensor: torch.Tensor,
                  seq_lens: torch.Tensor):
    return MLACommonDecodeMetadata(
        block_table=block_table_tensor,
        seq_lens=seq_lens,
    )

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
) -> M
Source code in vllm/v1/attention/backends/mla/common.py
def build(self, common_prefix_len: int,
          common_attn_metadata: CommonAttentionMetadata) -> M:
    num_reqs = common_attn_metadata.num_reqs
    num_actual_tokens = common_attn_metadata.num_actual_tokens
    max_query_len = common_attn_metadata.max_query_len

    assert self._num_decodes + self._num_prefills == num_reqs

    # Note(simon): be careful about the CPU <> GPU memory movement in this
    # function. We should avoid GPU -> CPU sync as much as possible because
    # it blocks on all previous kernels.
    device = self.runner.device
    block_table = self.block_table
    block_table_tensor = block_table.get_device_tensor()[:num_reqs]
    block_table.slot_mapping[:num_actual_tokens].copy_(
        block_table.slot_mapping_cpu[:num_actual_tokens],
        non_blocking=True)
    block_table.slot_mapping[num_actual_tokens:].fill_(-1)
    slot_mapping = block_table.slot_mapping[:num_actual_tokens]

    query_start_loc = common_attn_metadata.query_start_loc
    seq_lens = common_attn_metadata.seq_lens

    prefill_metadata = None
    if self._num_prefills > 0:
        reqs_start = self._num_decodes  # prefill_start

        context_lens_cpu = self.runner.input_batch.\
            num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
        max_context_len_cpu = context_lens_cpu.max().item()
        num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
        prefill_query_start_loc = query_start_loc[
            reqs_start:] - query_start_loc[reqs_start]

        chunked_context_metadata = None
        if self.chunked_prefill_enabled and self._num_prefills > 0 \
            and max_context_len_cpu > 0:
            # NOTE: it is recommended you read the `Chunked Prefill` section
            # in the comment at the top of the file before trying to
            # understand the following code

            # currently we allocate an equal amount of workspace for each
            # prefill in the batch; we could probably use a more advanced
            # algorithm here and allocate more workspace to prefills with
            # longer context lengths
            max_context_chunk = (self.chunked_prefill_workspace_size //
                                 num_prefills_with_context_cpu)

            if self.aot_schedule:
                # align max_context_chunk to page_size by rounding down,
                # currently the `gather_cache` kernel cannot handle
                # `context_chunk_starts` that are not aligned to page_size
                max_context_chunk = round_down(max_context_chunk,
                                               self.page_size)

            assert max_context_chunk > 0
            num_chunks = cdiv(max_context_len_cpu, max_context_chunk)

            # if `max_context_chunk = 256`, `num_chunks = 3`, and
            #   `num_prefills_with_context = 4`, create a tensor that looks
            # like
            #  [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
            # Note(simon): this is done on the CPU because of the
            # downstream use of `to_list`.
            chunk_starts = \
                torch.arange(num_chunks, dtype=torch.int32) \
                .unsqueeze(1).expand(-1, self._num_prefills) \
                * max_context_chunk
            chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                   chunk_starts + max_context_chunk)
            chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

            cu_seq_lens_cpu = torch.zeros(num_chunks,
                                          self._num_prefills + 1,
                                          dtype=torch.int32,
                                          pin_memory=True)
            torch.cumsum(chunk_seq_lens,
                         dim=1,
                         out=cu_seq_lens_cpu[:, 1:],
                         dtype=torch.int32)

            chunked_context_metadata = \
                MLACommonPrefillMetadata.ChunkedContextMetadata(
                cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
                starts=chunk_starts.to(device, non_blocking=True),
                seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                workspace=self.chunked_prefill_workspace,
            )

            assert max(chunked_context_metadata.max_seq_lens) <= \
                self.chunked_prefill_workspace_size

        prefill_metadata = MLACommonPrefillMetadata(
            block_table=block_table_tensor[reqs_start:, ...],
            query_start_loc=prefill_query_start_loc,
            max_query_len=max_query_len,
            chunked_context=chunked_context_metadata,
        )

    decode_metadata = None
    if self._num_decodes > 0:
        decode_metadata = self._build_decode(
            block_table_tensor=block_table_tensor[:self._num_decodes, ...],
            seq_lens=seq_lens[:self._num_decodes],
        )

    return self.metadata_cls(
        num_actual_tokens=num_actual_tokens,
        query_start_loc=query_start_loc,
        slot_mapping=slot_mapping,
        head_dim=self.runner.model_config.get_head_size(),
        # MLACommonMetadata Chunk prefill specific
        num_decodes=self._num_decodes,
        num_decode_tokens=self._num_decode_tokens,
        num_prefills=self._num_prefills,
        prefill=prefill_metadata,
        decode=decode_metadata,
    )
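
To make the chunked-context bookkeeping in build concrete, below is a small self-contained sketch of the same chunking arithmetic (max_context_chunk, chunk_starts, chunk_ends, chunk_seq_lens, cu_seq_lens) on a toy batch. The workspace size, page size, and context lengths are made up for illustration, and round_down/cdiv are re-implemented locally to keep the sketch standalone:

import torch

# Toy inputs (illustrative only).
workspace_size = 768   # tokens of chunked-prefill workspace
page_size = 16
context_lens_cpu = torch.tensor([100, 300, 500, 700], dtype=torch.int32)
num_prefills = context_lens_cpu.numel()
num_prefills_with_context = int((context_lens_cpu > 0).sum())
max_context_len = int(context_lens_cpu.max())

def round_down(x: int, multiple: int) -> int:
    return x - (x % multiple)

def cdiv(a: int, b: int) -> int:
    return -(-a // b)

# Equal workspace per prefill with context, aligned down to page_size.
max_context_chunk = round_down(workspace_size // num_prefills_with_context,
                               page_size)              # 192
num_chunks = cdiv(max_context_len, max_context_chunk)  # 4

chunk_starts = (torch.arange(num_chunks, dtype=torch.int32)
                .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk)
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                       chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

cu_seq_lens = torch.zeros(num_chunks, num_prefills + 1, dtype=torch.int32)
torch.cumsum(chunk_seq_lens, dim=1, out=cu_seq_lens[:, 1:],
             dtype=torch.int32)

print(chunk_starts[:, 0].tolist())   # [0, 192, 384, 576]
print(chunk_seq_lens.tolist())
# [[100, 192, 192, 192],   # chunk 0: first 192 context tokens per request
#  [  0, 108, 192, 192],   # chunk 1
#  [  0,   0, 116, 192],   # chunk 2
#  [  0,   0,   0, 124]]   # chunk 3

Each row of chunk_seq_lens sums to at most workspace_size (676, 492, 308, and 124 here), so every chunk's gathered context fits in the shared workspace by construction.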

build_for_cudagraph_capture

build_for_cudagraph_capture(
    common_attn_metadata: CommonAttentionMetadata,
) -> M

This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with MLA.

Source code in vllm/v1/attention/backends/mla/common.py
def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata) -> M:
    """
    This method builds the metadata for full cudagraph capture.
    Currently, only decode is supported for full cudagraphs with MLA.
    """
    m = common_attn_metadata
    assert m.num_reqs == m.num_actual_tokens, \
        "MLA only supports decode-only full CUDAGraph capture. " \
        "Make sure all cudagraph capture sizes <= max_num_seq."

    m.max_query_len = 1  # decode-only

    # Update state usually set in reorder_batch.
    self._num_decodes = m.num_reqs
    self._num_decode_tokens = m.num_actual_tokens
    self._num_prefills = 0
    self._num_prefill_tokens = 0
    return self.build(0, m)

can_run_in_cudagraph

can_run_in_cudagraph(
    common_attn_metadata: CommonAttentionMetadata,
) -> bool
Source code in vllm/v1/attention/backends/mla/common.py
def can_run_in_cudagraph(
        self, common_attn_metadata: CommonAttentionMetadata) -> bool:
    return common_attn_metadata.max_query_len == 1

reorder_batch

reorder_batch(
    input_batch: InputBatch,
    scheduler_output: SchedulerOutput,
) -> bool
Source code in vllm/v1/attention/backends/mla/common.py
def reorder_batch(self, input_batch: "InputBatch",
                  scheduler_output: "SchedulerOutput") -> bool:
    # We now want to reorder the batch so that the "decode" requests are at
    # the front and the "prefill" requests are at the back, using the least
    # amount of swaps possible. (NOTE for now we loosely use "decode" to mean
    # requests where attention is likely memory-bound and "prefill" to mean
    # requests where attention is likely compute-bound, TODO(lucas): figure
    # out a better naming here)
    decodes = []
    prefills = []
    num_decode_tokens = 0
    num_prefill_tokens = 0

    for i, req_id in enumerate(input_batch.req_ids):
        num_tokens = scheduler_output.num_scheduled_tokens[req_id]
        # for now treat 1 scheduled token as "decode" even if it's not,
        # we should update this to something like < 8 in the future but
        # currently the TritonMLA._forward_decode only supports
        # num_tokens = 1
        if num_tokens == 1:
            decodes.append(i)
            num_decode_tokens += num_tokens
        else:
            prefills.append(i)
            num_prefill_tokens += num_tokens

    # We hope that the number of swaps is fairly minimal since decodes
    # should be around for a number of iterations so hopefully they are
    # relatively stationary (and new requests are generally appended to the
    # persistent batch so they should already be at the back).
    # To achieve this we loop over the decodes in descending order and
    # the prefills in ascending order. We swap decodes from the "back",
    # i.e. past where the last decode should be in the reordered batch,
    # with prefills from the front of the batch.
    # `decodes` and `prefills` are already in ascending order just based on
    # the above loop.
    num_decodes = len(decodes)
    num_prefills = len(prefills)
    modified_batch = False

    for i in range(1, min(num_decodes, num_prefills) + 1):
        # If the decode is at the "back" of the batch, i.e. beyond the
        # first `num_decodes` slots, we can swap it with the prefill
        # closest to the front of the batch
        decode_idx = decodes[num_decodes - i]
        if decode_idx < num_decodes:
            break

        input_batch.swap_states(prefills[i - 1], decode_idx)
        modified_batch = True

    # Save for next `build` call
    # TODO(lucas): this is a bit of a hack, we should probably have a
    # better way of doing this
    self._num_decodes = num_decodes
    self._num_prefills = num_prefills
    self._num_decode_tokens = num_decode_tokens
    self._num_prefill_tokens = num_prefill_tokens

    return modified_batch
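
As a sanity check on the swapping logic above, here is a minimal sketch of the same reordering over plain index lists (no InputBatch or scheduler objects), assuming a hypothetical batch where requests scheduled for exactly 1 token count as "decode":

# Hypothetical per-request scheduled-token counts, in current batch order.
num_scheduled_tokens = [1, 7, 1, 5, 1]   # reqs 0, 2, 4 are "decode"

decodes = [i for i, n in enumerate(num_scheduled_tokens) if n == 1]
prefills = [i for i, n in enumerate(num_scheduled_tokens) if n != 1]
num_decodes, num_prefills = len(decodes), len(prefills)

order = list(range(len(num_scheduled_tokens)))   # stands in for swap_states
for i in range(1, min(num_decodes, num_prefills) + 1):
    # Walk the decodes from the back; stop as soon as they already sit
    # within the first `num_decodes` slots.
    decode_idx = decodes[num_decodes - i]
    if decode_idx < num_decodes:
        break
    # Swap this "late" decode with the earliest remaining prefill.
    a, b = prefills[i - 1], decode_idx
    order[a], order[b] = order[b], order[a]

print(order)   # [0, 4, 2, 3, 1] -> decodes (reqs 0, 4, 2) now fill the front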

MLACommonPrefillMetadata dataclass

Prefill Specific Metadata

Source code in vllm/v1/attention/backends/mla/common.py
@dataclass
class MLACommonPrefillMetadata:
    """ Prefill Specific Metadata """

    @dataclass
    class ChunkedContextMetadata:
        # New for MLA (compared to FlashAttention)
        # For handling chunked prefill
        cu_seq_lens: torch.Tensor
        starts: torch.Tensor
        seq_tot: list[int]
        max_seq_lens: list[int]
        workspace: torch.Tensor

    block_table: torch.Tensor
    query_start_loc: torch.Tensor
    max_query_len: int
    chunked_context: Optional[ChunkedContextMetadata] = None

block_table instance-attribute

block_table: Tensor

chunked_context class-attribute instance-attribute

chunked_context: Optional[ChunkedContextMetadata] = None

max_query_len instance-attribute

max_query_len: int

query_start_loc instance-attribute

query_start_loc: Tensor

ChunkedContextMetadata dataclass

Source code in vllm/v1/attention/backends/mla/common.py
@dataclass
class ChunkedContextMetadata:
    # New for MLA (compared to FlashAttention)
    # For handling chunked prefill
    cu_seq_lens: torch.Tensor
    starts: torch.Tensor
    seq_tot: list[int]
    max_seq_lens: list[int]
    workspace: torch.Tensor

cu_seq_lens instance-attribute

cu_seq_lens: Tensor

max_seq_lens instance-attribute

max_seq_lens: list[int]

seq_tot instance-attribute

seq_tot: list[int]

starts instance-attribute

starts: Tensor

workspace instance-attribute

workspace: Tensor

__init__

__init__(
    cu_seq_lens: Tensor,
    starts: Tensor,
    seq_tot: list[int],
    max_seq_lens: list[int],
    workspace: Tensor,
) -> None

__init__

__init__(
    block_table: Tensor,
    query_start_loc: Tensor,
    max_query_len: int,
    chunked_context: Optional[
        ChunkedContextMetadata
    ] = None,
) -> None
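
For orientation, a hedged sketch of constructing this metadata by hand with toy tensors (in practice it is produced by MLACommonMetadataBuilder.build); all shapes and values below are illustrative only:

import torch

from vllm.v1.attention.backends.mla.common import MLACommonPrefillMetadata

# Two prefill requests, a 4-slot block table each, fp16 workspace.
prefill = MLACommonPrefillMetadata(
    block_table=torch.zeros(2, 4, dtype=torch.int32),
    query_start_loc=torch.tensor([0, 8, 20], dtype=torch.int32),
    max_query_len=12,
    chunked_context=MLACommonPrefillMetadata.ChunkedContextMetadata(
        cu_seq_lens=torch.tensor([[0, 64, 128]], dtype=torch.int32),
        starts=torch.zeros(1, 2, dtype=torch.int32),
        seq_tot=[128],
        max_seq_lens=[64],
        workspace=torch.empty(256, 576, dtype=torch.float16),
    ),
)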