vllm.attention.ops.rocm_aiter_mla
aiter_mla_decode_fwd
aiter_mla_decode_fwd(
q: Tensor,
kv_buffer: Tensor,
o: Tensor,
sm_scale: float,
qo_indptr: Tensor,
max_seqlen_qo: int,
kv_indptr: Optional[Tensor] = None,
kv_indices: Optional[Tensor] = None,
kv_last_page_lens: Optional[Tensor] = None,
logit_cap: float = 0.0,
)
Source code in vllm/attention/ops/rocm_aiter_mla.py
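A minimal usage sketch follows. The shapes, dtypes, page size, and paged-KV layout below are assumptions chosen for illustration, not the kernel's authoritative contract; the call writes its result into the preallocated `o` tensor and returns nothing.

```python
# Hedged sketch: shapes, dtypes and the paged-KV layout are assumptions for
# illustration only; consult the AITER kernel for the exact expected layout.
import torch

from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd

device = torch.device("cuda")
batch_size = 2          # one decode token per request
num_heads = 16          # assumed number of query heads
head_dim = 576          # assumed MLA head size
block_size = 1          # assumed KV-cache page size
num_blocks = 8

q = torch.randn(batch_size, num_heads, head_dim, dtype=torch.bfloat16, device=device)
kv_buffer = torch.randn(num_blocks, block_size, 1, head_dim, dtype=torch.bfloat16, device=device)
o = torch.empty_like(q)  # assumed output shape; the real output head width may differ

# qo_indptr = [0, 1, 2]: each request contributes exactly one query token.
qo_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device=device)
# Toy paged-KV metadata: request 0 owns pages 0-2, request 1 owns pages 3-6.
kv_indptr = torch.tensor([0, 3, 7], dtype=torch.int32, device=device)
kv_indices = torch.arange(7, dtype=torch.int32, device=device)
kv_last_page_lens = torch.full((batch_size,), block_size, dtype=torch.int32, device=device)

aiter_mla_decode_fwd(
    q,
    kv_buffer,
    o,
    sm_scale=head_dim ** -0.5,
    qo_indptr=qo_indptr,
    max_seqlen_qo=1,
    kv_indptr=kv_indptr,
    kv_indices=kv_indices,
    kv_last_page_lens=kv_last_page_lens,
)
# The result is written in place into `o`.
```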
get_aiter_mla_metadata
get_aiter_mla_metadata(
max_batch_size: int,
block_size: int,
max_block_per_batch: int,
device: device,
) -> tuple[Tensor, ...]
Source code in vllm/attention/ops/rocm_aiter_mla.py
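A hedged sketch of how the helper might be called when setting up decode metadata. Treat the contents of the returned tuple as an assumption here (preallocated paged-KV index / indptr / last-page-length buffers plus a query indptr); check the source for the exact tensors and their ordering.

```python
# Sketch only: the exact tensors in the returned tuple are an assumption;
# see vllm/attention/ops/rocm_aiter_mla.py for the real contents.
import torch

from vllm.attention.ops.rocm_aiter_mla import get_aiter_mla_metadata

metadata = get_aiter_mla_metadata(
    max_batch_size=32,
    block_size=1,
    max_block_per_batch=128,
    device=torch.device("cuda"),
)
# The preallocated buffers are filled per scheduling step and then passed to
# aiter_mla_decode_fwd as its qo_indptr / kv_* arguments.
for t in metadata:
    print(t.shape, t.dtype, t.device)
```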
mla_decode_fwd_fake
mla_decode_fwd_fake(
q: Tensor,
kv_buffer: Tensor,
o: Tensor,
qo_indptr: Tensor,
max_seqlen_qo: int,
kv_indptr: Optional[Tensor] = None,
kv_indices: Optional[Tensor] = None,
kv_last_page_lens: Optional[Tensor] = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None
Source code in vllm/attention/ops/rocm_aiter_mla.py
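This is the shape-only ("fake") counterpart used when the op is traced, for example under torch.compile. Because the real kernel writes its output into the preallocated `o` tensor and returns None, the fake implementation has nothing to compute, as the generic sketch below (hypothetical name, not the vLLM source) illustrates.

```python
# Illustration of the fake-impl pattern for an out-variant op; the function
# name is hypothetical. Nothing needs to be computed during tracing because
# the real kernel mutates the preallocated `o` tensor in place.
import torch

def decode_fwd_fake(q: torch.Tensor, kv_buffer: torch.Tensor, o: torch.Tensor,
                    *args, **kwargs) -> None:
    return None
```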
mla_decode_fwd_impl
mla_decode_fwd_impl(
q: Tensor,
kv_buffer: Tensor,
o: Tensor,
qo_indptr: Tensor,
max_seqlen_qo: int,
kv_indptr: Optional[Tensor] = None,
kv_indices: Optional[Tensor] = None,
kv_last_page_lens: Optional[Tensor] = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None
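Source code in vllm/attention/ops/rocm_aiter_mla.py

mla_decode_fwd_impl is the eager implementation that dispatches to the AITER MLA decode kernel; paired with the fake variant above, it forms the kind of impl/fake pair that gets registered as a PyTorch custom op so the call stays traceable. Below is a generic registration sketch under assumed names using torch.library.custom_op (PyTorch >= 2.4); it is not the registration code used in vLLM.

```python
# Generic custom-op registration sketch; op and function names are assumptions.
import torch

@torch.library.custom_op("mylib::mla_decode_fwd", mutates_args=("o",))
def mla_decode_fwd(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    # Real implementation: call the vendor decode kernel here and write the
    # result into `o` in place.
    ...

@mla_decode_fwd.register_fake
def _(q, kv_buffer, o, qo_indptr, max_seqlen_qo, sm_scale=1.0, logit_cap=0.0) -> None:
    # Fake/meta implementation: no work needed since `o` is preallocated.
    return None
```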