vllm.v1.attention.backends.mla.flashmla

logger module-attribute

logger = init_logger(__name__)

FlashMLABackend

Bases: MLACommonBackend

Source code in vllm/v1/attention/backends/mla/flashmla.py
class FlashMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "FLASHMLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type["FlashMLAMetadata"]:
        return FlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLAImpl"]:
        return FlashMLAImpl
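
The static accessors bind the backend name to its metadata, builder, and implementation classes. A minimal sketch of how these accessors relate (illustrative only; backend selection and wiring is handled inside vLLM):

from vllm.v1.attention.backends.mla.flashmla import FlashMLABackend

assert FlashMLABackend.get_name() == "FLASHMLA_VLLM_V1"
metadata_cls = FlashMLABackend.get_metadata_cls()  # -> FlashMLAMetadata
builder_cls = FlashMLABackend.get_builder_cls()    # -> FlashMLAMetadataBuilder
impl_cls = FlashMLABackend.get_impl_cls()          # -> FlashMLAImpl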

get_builder_cls staticmethod

get_builder_cls() -> type[FlashMLAMetadataBuilder]
Source code in vllm/v1/attention/backends/mla/flashmla.py
@staticmethod
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
    return FlashMLAMetadataBuilder

get_impl_cls staticmethod

get_impl_cls() -> type[FlashMLAImpl]
Source code in vllm/v1/attention/backends/mla/flashmla.py
@staticmethod
def get_impl_cls() -> type["FlashMLAImpl"]:
    return FlashMLAImpl

get_metadata_cls staticmethod

get_metadata_cls() -> type[FlashMLAMetadata]
Source code in vllm/v1/attention/backends/mla/flashmla.py
@staticmethod
def get_metadata_cls() -> type["FlashMLAMetadata"]:
    return FlashMLAMetadata

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/mla/flashmla.py
@staticmethod
def get_name() -> str:
    return "FLASHMLA_VLLM_V1"

FlashMLADecodeMetadata dataclass

Bases: MLACommonDecodeMetadata

Source code in vllm/v1/attention/backends/mla/flashmla.py
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
    tile_scheduler_metadata: torch.Tensor
    num_splits: torch.Tensor

num_splits instance-attribute

num_splits: Tensor

tile_scheduler_metadata instance-attribute

tile_scheduler_metadata: Tensor

__init__

__init__(
    block_table: Tensor,
    seq_lens: Tensor,
    tile_scheduler_metadata: Tensor,
    num_splits: Tensor,
) -> None
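
A hedged sketch of constructing the decode metadata by hand. The tensor shapes and dtypes below are assumptions for illustration, not a documented contract; in practice tile_scheduler_metadata and num_splits come from get_mla_metadata, and the builder's comments describe them as per-SM (#SMs, TileMetadataSize) and per-batch (batch_size,) tensors respectively:

import torch
from vllm.v1.attention.backends.mla.flashmla import FlashMLADecodeMetadata

# Assumed sizes for illustration only.
batch_size, max_blocks_per_req = 4, 16

decode_meta = FlashMLADecodeMetadata(
    block_table=torch.zeros(batch_size, max_blocks_per_req, dtype=torch.int32),
    seq_lens=torch.tensor([7, 12, 3, 9], dtype=torch.int32),
    # Placeholder values; normally produced by get_mla_metadata(seq_lens, ...).
    tile_scheduler_metadata=torch.zeros(1, dtype=torch.int32),
    num_splits=torch.zeros(batch_size, dtype=torch.int32),
)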

FlashMLAImpl

Bases: MLACommonImpl[FlashMLAMetadata]

Source code in vllm/v1/attention/backends/mla/flashmla.py
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        assert is_flashmla_supported(), \
            "FlashMLA is not supported on this device"

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlashMLA V1 with FP8 KV cache not yet supported")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        q = torch.cat([q_nope, q_pe], dim=-1)\
            .unsqueeze(1) # Add seqlen dim of 1 (decode)

        o, _ = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=attn_metadata.decode.block_table,
            cache_seqlens=attn_metadata.decode.seq_lens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=attn_metadata.decode.
            tile_scheduler_metadata,
            num_splits=attn_metadata.decode.num_splits,
            softmax_scale=self.scale,
            causal=True,
        )

        return self._v_up_proj(o)

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[dict[str, Any]],
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    **mla_args,
) -> None
Source code in vllm/v1/attention/backends/mla/flashmla.py
def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]],
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        **mla_args) -> None:
    super().__init__(num_heads, head_size, scale, num_kv_heads,
                     alibi_slopes, sliding_window, kv_cache_dtype,
                     blocksparse_params, logits_soft_cap, attn_type,
                     kv_sharing_target_layer_name, **mla_args)

    assert is_flashmla_supported(), \
        "FlashMLA is not supported on this device"

    unsupported_features = [
        alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
    ]
    if any(unsupported_features):
        raise NotImplementedError(
            "FlashMLAImpl does not support one of the following: "
            "alibi_slopes, sliding_window, blocksparse_params, "
            "logits_soft_cap")

    if attn_type != AttentionType.DECODER:
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "FlashMLAImpl")

    if is_quantized_kv_cache(self.kv_cache_dtype):
        raise NotImplementedError(
            "FlashMLA V1 with FP8 KV cache not yet supported")

_forward_decode

_forward_decode(
    q_nope: Tensor,
    q_pe: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    attn_metadata: FlashMLAMetadata,
) -> Tensor
Source code in vllm/v1/attention/backends/mla/flashmla.py
def _forward_decode(
    self,
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata: FlashMLAMetadata,
) -> torch.Tensor:
    assert kv_c_and_k_pe_cache.numel() > 0
    assert attn_metadata.decode is not None

    q = torch.cat([q_nope, q_pe], dim=-1)\
        .unsqueeze(1) # Add seqlen dim of 1 (decode)

    o, _ = flash_mla_with_kvcache(
        q=q,
        k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
        block_table=attn_metadata.decode.block_table,
        cache_seqlens=attn_metadata.decode.seq_lens,
        head_dim_v=self.kv_lora_rank,
        tile_scheduler_metadata=attn_metadata.decode.
        tile_scheduler_metadata,
        num_splits=attn_metadata.decode.num_splits,
        softmax_scale=self.scale,
        causal=True,
    )

    return self._v_up_proj(o)
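
The decode path concatenates the non-rotary and rotary query parts along the last dimension and inserts a singleton sequence dimension before calling the FlashMLA kernel; the paged cache likewise gets a singleton head dimension. A minimal shape sketch in plain torch (the dimension names and sizes are assumptions for illustration):

import torch

# Assumed sizes for illustration only.
num_tokens, num_heads = 4, 16
kv_lora_rank, qk_rope_head_dim = 512, 64

q_nope = torch.randn(num_tokens, num_heads, kv_lora_rank)
q_pe = torch.randn(num_tokens, num_heads, qk_rope_head_dim)

# Same reshaping as _forward_decode: concat along the last dim,
# then add a seqlen dimension of 1 for single-token decode.
q = torch.cat([q_nope, q_pe], dim=-1).unsqueeze(1)
assert q.shape == (num_tokens, 1, num_heads, kv_lora_rank + qk_rope_head_dim)

# The paged cache gets a head dimension of 1 before the kernel call.
num_blocks, block_size = 8, 64
kv_c_and_k_pe_cache = torch.randn(
    num_blocks, block_size, kv_lora_rank + qk_rope_head_dim)
k_cache = kv_c_and_k_pe_cache.unsqueeze(-2)
assert k_cache.shape == (num_blocks, block_size, 1,
                         kv_lora_rank + qk_rope_head_dim)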

FlashMLAMetadata dataclass

Bases: MLACommonMetadata[FlashMLADecodeMetadata]

Source code in vllm/v1/attention/backends/mla/flashmla.py
@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
    pass

__init__

__init__(
    num_actual_tokens: int,
    query_start_loc: Tensor,
    slot_mapping: Tensor,
    num_decodes: int,
    num_decode_tokens: int,
    num_prefills: int,
    head_dim: Optional[int] = None,
    decode: Optional[D] = None,
    prefill: Optional[MLACommonPrefillMetadata] = None,
) -> None

FlashMLAMetadataBuilder

Bases: MLACommonMetadataBuilder[FlashMLAMetadata]

Source code in vllm/v1/attention/backends/mla/flashmla.py
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    full_cudagraph_supported: ClassVar[bool] = True  # Decode-only

    def __init__(self, runner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)

        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config)

        self.cg_buf_tile_scheduler_metadata = None
        self.cg_buf_num_splits = None

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
        tile_scheduler_metadata, num_splits = \
            get_mla_metadata(
            seq_lens,
            self.num_q_heads,
            1, # MQA for the decode path
        )

        if self.runner.full_cuda_graph:
            # First time around (CUDAGraph capture), allocate the static buffer
            if self.cg_buf_tile_scheduler_metadata is None:
                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
                self.cg_buf_num_splits = num_splits
            else:
                assert self.cg_buf_num_splits is not None

                # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
                assert (self.cg_buf_tile_scheduler_metadata.size() ==
                        tile_scheduler_metadata.size())
                self.cg_buf_tile_scheduler_metadata.\
                    copy_(tile_scheduler_metadata)
                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata

                # Num splits is per-batch, varying size (batch_size,)
                n = num_splits.size(0)
                # make sure static buffer is large enough
                assert n <= self.cg_buf_num_splits.size(0)
                num_splits_view = self.cg_buf_num_splits[:n]
                num_splits_view.copy_(num_splits)
                self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
                num_splits = num_splits_view

        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens,
            tile_scheduler_metadata=tile_scheduler_metadata,
            num_splits=num_splits,
        )

cg_buf_num_splits instance-attribute

cg_buf_num_splits = None

cg_buf_tile_scheduler_metadata instance-attribute

cg_buf_tile_scheduler_metadata = None

full_cudagraph_supported class-attribute

full_cudagraph_supported: bool = True

num_q_heads instance-attribute

num_q_heads = get_num_attention_heads(parallel_config)

__init__

__init__(
    runner,
    kv_cache_spec: AttentionSpec,
    block_table: BlockTable,
)
Source code in vllm/v1/attention/backends/mla/flashmla.py
def __init__(self, runner, kv_cache_spec: AttentionSpec,
             block_table: BlockTable):
    super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)

    self.num_q_heads = self.runner.model_config.get_num_attention_heads(
        self.runner.parallel_config)

    self.cg_buf_tile_scheduler_metadata = None
    self.cg_buf_num_splits = None

_build_decode

_build_decode(
    block_table_tensor: Tensor, seq_lens: Tensor
) -> FlashMLADecodeMetadata
Source code in vllm/v1/attention/backends/mla/flashmla.py
def _build_decode(self, block_table_tensor: torch.Tensor,
                  seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
    tile_scheduler_metadata, num_splits = \
        get_mla_metadata(
        seq_lens,
        self.num_q_heads,
        1, # MQA for the decode path
    )

    if self.runner.full_cuda_graph:
        # First time around (CUDAGraph capture), allocate the static buffer
        if self.cg_buf_tile_scheduler_metadata is None:
            self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
            self.cg_buf_num_splits = num_splits
        else:
            assert self.cg_buf_num_splits is not None

            # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
            assert (self.cg_buf_tile_scheduler_metadata.size() ==
                    tile_scheduler_metadata.size())
            self.cg_buf_tile_scheduler_metadata.\
                copy_(tile_scheduler_metadata)
            tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata

            # Num splits is per-batch, varying size (batch_size,)
            n = num_splits.size(0)
            # make sure static buffer is large enough
            assert n <= self.cg_buf_num_splits.size(0)
            num_splits_view = self.cg_buf_num_splits[:n]
            num_splits_view.copy_(num_splits)
            self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
            num_splits = num_splits_view

    return FlashMLADecodeMetadata(
        block_table=block_table_tensor,
        seq_lens=seq_lens,
        tile_scheduler_metadata=tile_scheduler_metadata,
        num_splits=num_splits,
    )