vllm.attention.backends.flashmla
FlashMLABackend

Bases: MLACommonBackend
get_builder_cls (staticmethod)

    get_builder_cls() -> Type[FlashMLAMetadataBuilder]
get_impl_cls (staticmethod)

    get_impl_cls() -> Type[FlashMLAImpl]
get_metadata_cls (staticmethod)

    get_metadata_cls() -> Type[FlashMLAMetadata]
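
Together, these three hooks are how the rest of vLLM discovers the FlashMLA-specific classes documented on this page. A minimal sketch of that resolution, assuming the module import succeeds on your build (FlashMLA is CUDA-only, so importing may require a supported GPU install):

```python
# Sketch: resolving the concrete classes exposed by FlashMLABackend.
# The return values follow the signatures documented above.
from vllm.attention.backends.flashmla import FlashMLABackend

impl_cls = FlashMLABackend.get_impl_cls()          # FlashMLAImpl
metadata_cls = FlashMLABackend.get_metadata_cls()  # FlashMLAMetadata
builder_cls = FlashMLABackend.get_builder_cls()    # FlashMLAMetadataBuilder

print(impl_cls.__name__, metadata_cls.__name__, builder_cls.__name__)
```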
FlashMLAImpl

Bases: MLACommonImpl[FlashMLAMetadata]
__init__

    __init__(
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]],
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str] = None,
        **mla_args,
    ) -> None
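
The constructor mirrors the standard vLLM attention-impl interface, with the MLA-specific projection arguments forwarded through **mla_args to MLACommonImpl. Below is a sketch of the documented parameters with illustrative DeepSeek-style values; the concrete numbers, and the omission of **mla_args, are assumptions for illustration only (in practice the attention layer supplies all of this):

```python
# Illustrative argument set for FlashMLAImpl.__init__; values are assumed,
# not taken from any particular model config.
flashmla_impl_kwargs = dict(
    num_heads=128,            # query heads
    head_size=576,            # assumed MLA decode layout: 512 latent + 64 rope dims
    scale=192 ** -0.5,        # illustrative softmax scale
    num_kv_heads=1,           # MLA keeps a single compressed KV head
    alibi_slopes=None,
    sliding_window=None,
    kv_cache_dtype="auto",
    blocksparse_params=None,
    logits_soft_cap=None,
    attn_type="decoder",
    kv_sharing_target_layer_name=None,
    # **mla_args: the MLA projection / rotary-embedding arguments accepted by
    # MLACommonImpl are omitted from this sketch.
)
```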
_forward_decode

    _forward_decode(
        q_nope: Tensor,
        q_pe: Tensor,
        kv_c_and_k_pe_cache: Tensor,
        attn_metadata: FlashMLAMetadata,
    ) -> Tensor
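
_forward_decode receives the decode-time query already split into its latent ("nope") and rotary ("pe") parts, alongside the combined compressed-KV / rope-key cache. A shape-only sketch of that split; the 512 + 64 widths are an assumption based on the usual DeepSeek-style MLA layout:

```python
import torch

# Shape-only sketch of the inputs to _forward_decode (values are random).
num_decode_tokens, num_heads = 4, 128
kv_lora_rank, rope_dim = 512, 64   # assumed latent and rope widths

q = torch.randn(num_decode_tokens, num_heads, kv_lora_rank + rope_dim)
# The query arrives pre-split into its two components:
q_nope, q_pe = q.split([kv_lora_rank, rope_dim], dim=-1)

# kv_c_and_k_pe_cache stores, per cache slot, the compressed KV latent and
# the rope key side by side, i.e. a trailing dim of kv_lora_rank + rope_dim.
print(q_nope.shape, q_pe.shape)  # (4, 128, 512) (4, 128, 64)
```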
FlashMLAMetadata (dataclass)

Bases: MLACommonMetadata

decode_tile_scheduler_metadata (class-attribute, instance-attribute)

__init__
    __init__(
        num_prefills: int,
        num_prefill_tokens: int,
        num_decode_tokens: int,
        slot_mapping: Tensor,
        multi_modal_placeholder_index_maps: Optional[
            Dict[str, IndexMap]
        ],
        enable_kv_scales_calculation: bool,
        use_cuda_graph: bool,
        seq_lens: Optional[List[int]],
        seq_lens_tensor: Optional[Tensor],
        max_prefill_seq_len: int,
        max_decode_seq_len: int,
        context_lens_tensor: Optional[Tensor],
        block_tables: Optional[Tensor],
        max_query_len: Optional[int] = None,
        max_decode_query_len: Optional[int] = None,
        query_start_loc: Optional[Tensor] = None,
        seq_start_loc: Optional[Tensor] = None,
        _cached_prefill_metadata: Optional[Any] = None,
        _cached_decode_metadata: Optional[Any] = None,
        head_dim: Optional[int] = None,
        is_profile_run: bool = False,
        context_chunk_cu_seq_lens: Optional[Tensor] = None,
        context_chunk_starts: Optional[Tensor] = None,
        context_chunk_seq_tot: Optional[List[int]] = None,
        context_chunk_max_seq_lens: Optional[List[int]] = None,
        context_chunk_workspace: Optional[Tensor] = None,
        decode_tile_scheduler_metadata: Optional[
            Tuple[Tensor, Tensor]
        ] = None,
        decode_num_splits: Optional[Tensor] = None,
    ) -> None
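
The two fields FlashMLA adds on top of MLACommonMetadata, decode_tile_scheduler_metadata and decode_num_splits, are produced by the FlashMLA kernel's scheduler from the per-sequence cache lengths of the decode batch. A sketch of that step; the import path and exact signature of get_mla_metadata are assumptions based on vLLM's FlashMLA kernel wrapper, and running it requires a supported GPU:

```python
import torch

# Assumed wrapper location; verify against your vLLM version.
from vllm.attention.ops.flashmla import get_mla_metadata

# Per-sequence KV-cache lengths for a decode batch of three requests.
cache_seqlens = torch.tensor([523, 97, 1031], dtype=torch.int32, device="cuda")
num_q_heads, num_kv_heads = 128, 1   # MLA uses a single compressed KV head

tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    num_q_heads // num_kv_heads,
    num_kv_heads,
)
# These two tensors are what FlashMLAMetadata carries as
# decode_tile_scheduler_metadata and decode_num_splits.
```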
FlashMLAMetadataBuilder

Bases: MLACommonMetadataBuilder[FlashMLAMetadata]

__init__

build
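
The builder is the piece that assembles a FlashMLAMetadata for each batch; build() is where the decode scheduling fields above get filled in. A type-level sketch only, since constructing a builder requires a live model-input builder:

```python
from vllm.attention.backends.flashmla import (
    FlashMLABackend,
    FlashMLAMetadataBuilder,
)

# The backend hands out this builder class; its build() step produces the
# FlashMLAMetadata for a batch, including decode_tile_scheduler_metadata
# and decode_num_splits.
builder_cls = FlashMLABackend.get_builder_cls()
assert issubclass(builder_cls, FlashMLAMetadataBuilder)
```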
FlashMLAState

Bases: MLACommonState[FlashMLAMetadata]
__init__

get_graph_input_buffers

    get_graph_input_buffers(
        attn_metadata, is_encoder_decoder_model: bool = False
    )
graph_capture

    graph_capture(max_batch_size: int)
graph_capture_get_metadata_for_batch
prepare_graph_input_buffers

    prepare_graph_input_buffers(
        input_buffers,
        attn_metadata,
        is_encoder_decoder_model: bool = False,
    )
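
FlashMLAState extends the common MLA state so that the tile-scheduler metadata can take part in CUDA-graph capture and replay. The sketch below shows one plausible way the four methods fit together; the call pattern is an assumption based on vLLM's common attention-state interface, and because `state` is normally supplied by the model runner, the sketch is wrapped in functions rather than run directly:

```python
from vllm.attention.backends.flashmla import FlashMLAState


def capture_decode_graphs(state: FlashMLAState, max_batch_size: int) -> None:
    """Sketch of the capture-time flow (assumed, not the runner's actual code)."""
    with state.graph_capture(max_batch_size):
        for batch_size in range(1, max_batch_size + 1):
            # Per-batch-size metadata sized for capture; for FlashMLA this also
            # covers the decode tile-scheduler tensors.
            attn_metadata = state.graph_capture_get_metadata_for_batch(batch_size)
            buffers = state.get_graph_input_buffers(attn_metadata)
            # ... capture the model forward pass using `buffers` here ...


def refresh_for_replay(state: FlashMLAState, buffers, attn_metadata) -> None:
    """At replay time, copy the new batch's metadata into the captured buffers."""
    state.prepare_graph_input_buffers(buffers, attn_metadata)
```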