vllm.attention.ops.flashmla

_flashmla_C_AVAILABLE module-attribute

_flashmla_C_AVAILABLE = True

logger module-attribute

logger = init_logger(__name__)

flash_mla_with_kvcache

flash_mla_with_kvcache(
    q: Tensor,
    k_cache: Tensor,
    block_table: Tensor,
    cache_seqlens: Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: Tensor,
    num_splits: Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[Tensor, Tensor]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | Tensor | (batch_size, seq_len_q, num_heads_q, head_dim). | required |
| k_cache | Tensor | (num_blocks, page_block_size, num_heads_k, head_dim). | required |
| block_table | Tensor | (batch_size, max_num_blocks_per_seq), torch.int32. | required |
| cache_seqlens | Tensor | (batch_size), torch.int32. | required |
| head_dim_v | int | Head dim of v. | required |
| tile_scheduler_metadata | Tensor | (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. | required |
| num_splits | Tensor | (batch_size + 1), torch.int32, returned by get_mla_metadata. | required |
| softmax_scale | Optional[float] | The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim). | None |
| causal | bool | Whether to apply a causal attention mask. | False |

Return

out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

Source code in vllm/attention/ops/flashmla.py
def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dim of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
                                 torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
        softmax_scale: float. The scaling of QK^T before applying softmax.
                       Defaults to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1]**(-0.5)
    out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )
    return out, softmax_lse
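
A minimal decode-time sketch of how flash_mla_with_kvcache is driven by get_mla_metadata. The concrete sizes below (576 total head dim with 512 value dims, a single KV head, block size 64, bfloat16, one query token per sequence) are illustrative assumptions typical of MLA decoding, not values mandated by this API.

import torch

from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                         get_mla_metadata)

# Illustrative sizes (assumptions, not requirements of the API).
batch_size, seq_len_q = 4, 1          # single-token decode per sequence
num_heads_q, num_heads_k = 128, 1     # MLA uses a single latent KV head
head_dim, head_dim_v = 576, 512       # 512 latent dims + 64 rope dims
block_size, max_num_blocks_per_seq = 64, 32

device, dtype = "cuda", torch.bfloat16
q = torch.randn(batch_size, seq_len_q, num_heads_q, head_dim,
                device=device, dtype=dtype)
k_cache = torch.randn(batch_size * max_num_blocks_per_seq, block_size,
                      num_heads_k, head_dim, device=device, dtype=dtype)
block_table = torch.arange(batch_size * max_num_blocks_per_seq,
                           dtype=torch.int32, device=device).view(
                               batch_size, max_num_blocks_per_seq)
cache_seqlens = torch.full((batch_size,), 1024, dtype=torch.int32,
                           device=device)

# The scheduler metadata depends only on cache_seqlens and the head counts,
# so it can be computed once per step and reused across layers.
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens, seq_len_q * num_heads_q // num_heads_k, num_heads_k)

out, softmax_lse = flash_mla_with_kvcache(
    q, k_cache, block_table, cache_seqlens, head_dim_v,
    tile_scheduler_metadata, num_splits, causal=True)
# out: (4, 1, 128, 512), bfloat16; softmax_lse: (4, 128, 1), float32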

get_mla_metadata

get_mla_metadata(
    cache_seqlens: Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[Tensor, Tensor]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cache_seqlens | Tensor | (batch_size), dtype torch.int32. | required |
| num_heads_per_head_k | int | Equal to seq_len_q * num_heads_q // num_heads_k. | required |
| num_heads_k | int | The number of KV heads. | required |

Return

tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.

Source code in vllm/attention/ops/flashmla.py
def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equal to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: The number of KV heads.

    Return:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), 
                                 dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
                                                  num_heads_per_head_k,
                                                  num_heads_k)
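
A short sketch of calling get_mla_metadata on its own; the sequence lengths and head counts are illustrative. Note that num_heads_per_head_k is the ratio seq_len_q * num_heads_q // num_heads_k, not a head dimension.

import torch

from vllm.attention.ops.flashmla import get_mla_metadata

cache_seqlens = torch.tensor([512, 1024, 2048], dtype=torch.int32,
                             device="cuda")
seq_len_q, num_heads_q, num_heads_k = 1, 128, 1  # illustrative values

tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens, seq_len_q * num_heads_q // num_heads_k, num_heads_k)
# tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), int32
# num_splits: (batch_size + 1,) == (4,), int32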

is_flashmla_supported

is_flashmla_supported() -> Tuple[bool, Optional[str]]

Return: is_supported_flag, unsupported_reason (optional).

Source code in vllm/attention/ops/flashmla.py
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """
    Return: is_supported_flag, unsupported_reason (optional).
    """
    if not current_platform.is_cuda():
        return False, "FlashMLA is only supported on CUDA devices."
    if current_platform.get_device_capability()[0] != 9:
        return False, "FlashMLA is only supported on Hopper devices."
    if not _flashmla_C_AVAILABLE:
        return False, "vllm._flashmla_C is not available; it was likely not "\
            "compiled due to an insufficient nvcc version, or because a "\
            "supported arch (only sm90a currently) was not in the list of "\
            "target arches to compile for."
    return True, None
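
A typical guard pattern (a sketch, not code from vLLM itself): check support once and fall back to another attention backend when FlashMLA cannot be used.

from vllm.attention.ops.flashmla import is_flashmla_supported

supported, reason = is_flashmla_supported()
if not supported:
    # e.g. non-CUDA platform, non-Hopper GPU, or missing _flashmla_C extension
    print(f"FlashMLA unavailable, falling back: {reason}")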