vllm.attention.ops.flashmla
flash_mla_with_kvcache
flash_mla_with_kvcache(
    q: Tensor,
    k_cache: Tensor,
    block_table: Tensor,
    cache_seqlens: Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: Tensor,
    num_splits: Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[Tensor, Tensor]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| q | Tensor | (batch_size, seq_len_q, num_heads_q, head_dim). | required |
| k_cache | Tensor | (num_blocks, page_block_size, num_heads_k, head_dim). | required |
| block_table | Tensor | (batch_size, max_num_blocks_per_seq), torch.int32. | required |
| cache_seqlens | Tensor | (batch_size), torch.int32. | required |
| head_dim_v | int | Head dimension of v. | required |
| tile_scheduler_metadata | Tensor | (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. | required |
| num_splits | Tensor | (batch_size + 1), torch.int32, returned by get_mla_metadata. | required |
| softmax_scale | Optional[float] | The scale applied to QK^T before softmax. Defaults to 1 / sqrt(head_dim). | None |
| causal | bool | Whether to apply a causal attention mask. | False |
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
Source code in vllm/attention/ops/flashmla.py
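A minimal decode-style usage sketch, not taken from the vLLM source: the scheduling tensors must come from get_mla_metadata (documented below), and the sizes used here (one query token per sequence, 16 query heads sharing a single latent KV head, head_dim 576 with head_dim_v 512, bf16 cache) are illustrative assumptions; running it requires a CUDA device on which FlashMLA is supported.

```python
import torch

from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata)

# Illustrative sizes only (assumed, not prescribed by this API).
batch_size, seq_len_q = 4, 1
num_heads_q, num_heads_k = 16, 1
head_dim, head_dim_v = 576, 512
num_blocks, page_block_size, max_num_blocks_per_seq = 128, 64, 8

device, dtype = "cuda", torch.bfloat16
q = torch.randn(batch_size, seq_len_q, num_heads_q, head_dim,
                device=device, dtype=dtype)
k_cache = torch.randn(num_blocks, page_block_size, num_heads_k, head_dim,
                      device=device, dtype=dtype)
block_table = torch.randint(0, num_blocks,
                            (batch_size, max_num_blocks_per_seq),
                            device=device, dtype=torch.int32)
cache_seqlens = torch.full((batch_size,), 300,
                           device=device, dtype=torch.int32)

# Scheduling metadata comes from get_mla_metadata (see below).
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens, seq_len_q * num_heads_q // num_heads_k, num_heads_k)

out, softmax_lse = flash_mla_with_kvcache(
    q, k_cache, block_table, cache_seqlens, head_dim_v,
    tile_scheduler_metadata, num_splits, causal=True)
# out:         (batch_size, seq_len_q, num_heads_q, head_dim_v)
# softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32
```

Since the metadata depends only on cache_seqlens and the head configuration, it can typically be computed once per batch and reused for every attention call with the same shapes.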
get_mla_metadata
get_mla_metadata(
    cache_seqlens: Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[Tensor, Tensor]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cache_seqlens | Tensor | (batch_size), dtype torch.int32. | required |
| num_heads_per_head_k | int | Equal to seq_len_q * num_heads_q // num_heads_k. | required |
| num_heads_k | int | Number of key heads in k_cache. | required |
Returns:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
Source code in vllm/attention/ops/flashmla.py
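A hedged sketch of computing the metadata on its own; the decode configuration (one query token, 16 query heads, a single latent KV head) is an assumption used only to show how num_heads_per_head_k is derived from the documented formula.

```python
import torch

from vllm.attention.ops.flashmla import get_mla_metadata

# Assumed decode configuration: seq_len_q == 1, 16 query heads, 1 KV head.
seq_len_q, num_heads_q, num_heads_k = 1, 16, 1
cache_seqlens = torch.tensor([300, 17, 1024], dtype=torch.int32, device="cuda")

tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    seq_len_q * num_heads_q // num_heads_k,  # num_heads_per_head_k == 16
    num_heads_k,
)
# tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32
# num_splits:              (batch_size + 1,) == (4,), torch.int32
```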
is_flashmla_supported
Returns: (is_supported_flag, unsupported_reason); unsupported_reason is only set when FlashMLA is not supported.
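A usage sketch for guarding backend selection, assuming the call takes no arguments and returns the (flag, reason) pair described above:

```python
from vllm.attention.ops.flashmla import is_flashmla_supported

supported, reason = is_flashmla_supported()
if not supported:
    # Fall back to another attention backend; `reason` describes why
    # FlashMLA cannot be used (assumed to be None when it is supported).
    print(f"FlashMLA unavailable: {reason}")
```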