vllm.attention.ops.rocm_aiter_mla

aiter_mla_decode_fwd

aiter_mla_decode_fwd(
    q: Tensor,
    kv_buffer: Tensor,
    o: Tensor,
    sm_scale: float,
    qo_indptr: Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[Tensor] = None,
    kv_indices: Optional[Tensor] = None,
    kv_last_page_lens: Optional[Tensor] = None,
    logit_cap: float = 0.0,
)
Source code in vllm/attention/ops/rocm_aiter_mla.py
def aiter_mla_decode_fwd(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    sm_scale: float,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    logit_cap: float = 0.0,
):

    torch.ops.vllm.rocm_aiter_mla_decode_fwd(q,
                                             kv_buffer.view(
                                                 -1, 1, 1, q.shape[-1]),
                                             o,
                                             qo_indptr,
                                             max_seqlen_qo,
                                             kv_indptr,
                                             kv_indices,
                                             kv_last_page_lens,
                                             sm_scale=sm_scale,
                                             logit_cap=logit_cap)
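
A minimal usage sketch follows, assuming a ROCm build of vLLM with the AITER kernels available (otherwise the underlying custom op is not registered). All shapes, dtypes, and metadata values are illustrative assumptions, not requirements stated by this module; in practice they come from the MLA attention backend. The metadata tensors are allocated with get_aiter_mla_metadata, documented below.

import torch

from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd,
                                               get_aiter_mla_metadata)

device = torch.device("cuda")
batch_size, num_heads, head_dim = 4, 16, 576  # illustrative values only
block_size, max_blocks = 1, 128

kv_indices, kv_indptr, kv_last_page_lens, qo_indptr = get_aiter_mla_metadata(
    max_batch_size=batch_size,
    block_size=block_size,
    max_block_per_batch=max_blocks,
    device=device)

# Decode step: one query token per request, one KV page per request.
qo_indptr[1:] = torch.arange(1, batch_size + 1, device=device)
kv_indptr[1:] = torch.arange(1, batch_size + 1, device=device)
kv_indices[:batch_size] = torch.arange(batch_size, device=device)

q = torch.randn(batch_size, num_heads, head_dim,
                dtype=torch.bfloat16, device=device)
# Paged latent KV cache; the wrapper reshapes it to
# (num_pages, 1, 1, head_dim) before dispatching to AITER.
kv_buffer = torch.randn(batch_size * max_blocks, head_dim,
                        dtype=torch.bfloat16, device=device)
o = torch.empty_like(q)  # the kernel writes the output here in place

aiter_mla_decode_fwd(q, kv_buffer, o,
                     sm_scale=head_dim**-0.5,  # illustrative scale
                     qo_indptr=qo_indptr,
                     max_seqlen_qo=1,
                     kv_indptr=kv_indptr,
                     kv_indices=kv_indices,
                     kv_last_page_lens=kv_last_page_lens)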

get_aiter_mla_metadata

get_aiter_mla_metadata(
    max_batch_size: int,
    block_size: int,
    max_block_per_batch: int,
    device: device,
) -> tuple[Tensor, ...]
Source code in vllm/attention/ops/rocm_aiter_mla.py
def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
                           max_block_per_batch: int,
                           device: torch.device) -> tuple[torch.Tensor, ...]:
    paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch,
                                   dtype=torch.int32,
                                   device=device)
    paged_kv_indptr = torch.zeros(max_batch_size + 1,
                                  dtype=torch.int32,
                                  device=device)
    paged_kv_last_page_lens = torch.full((max_batch_size, ),
                                         block_size,
                                         dtype=torch.int32,
                                         device=device)
    qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int32, device=device)
    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr
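
The return order matches the return statement above. A quick sketch of the resulting shapes (the names on the left are hypothetical):

import torch

from vllm.attention.ops.rocm_aiter_mla import get_aiter_mla_metadata

indices, indptr, last_page_lens, qo_indptr = get_aiter_mla_metadata(
    max_batch_size=8, block_size=1, max_block_per_batch=128,
    device=torch.device("cuda"))

assert indices.shape == (8 * 128,)   # flat pool of page indices
assert indptr.shape == (8 + 1,)      # CSR-style per-request offsets into indices
assert last_page_lens.shape == (8,)  # valid entries in each request's last page
assert qo_indptr.shape == (8 + 1,)   # per-request query/output token offsets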

mla_decode_fwd_fake

mla_decode_fwd_fake(
    q: Tensor,
    kv_buffer: Tensor,
    o: Tensor,
    qo_indptr: Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[Tensor] = None,
    kv_indices: Optional[Tensor] = None,
    kv_last_page_lens: Optional[Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None
Source code in vllm/attention/ops/rocm_aiter_mla.py
def mla_decode_fwd_fake(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    pass
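
This no-op presumably serves as the fake (meta) implementation for the rocm_aiter_mla_decode_fwd custom op: the real kernel writes its result into o in place and returns nothing, so tracing machinery such as torch.compile needs no output shapes from the fake variant. A generic sketch of that pattern (the demo:: op below is hypothetical, not part of vLLM):

import torch

# Hypothetical op illustrating the in-place custom-op pattern: the
# eager implementation mutates `o`, so the fake variant can be a no-op.
@torch.library.custom_op("demo::inplace_copy", mutates_args=("o", ))
def inplace_copy(q: torch.Tensor, o: torch.Tensor) -> None:
    o.copy_(q)

@inplace_copy.register_fake
def _(q: torch.Tensor, o: torch.Tensor) -> None:
    pass  # nothing to compute; `o` already has the right shape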

mla_decode_fwd_impl

mla_decode_fwd_impl(
    q: Tensor,
    kv_buffer: Tensor,
    o: Tensor,
    qo_indptr: Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[Tensor] = None,
    kv_indices: Optional[Tensor] = None,
    kv_last_page_lens: Optional[Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None
Source code in vllm/attention/ops/rocm_aiter_mla.py
def mla_decode_fwd_impl(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    from aiter.mla import mla_decode_fwd

    mla_decode_fwd(q,
                   kv_buffer.view(-1, 1, 1, q.shape[-1]),
                   o,
                   qo_indptr,
                   kv_indptr,
                   kv_indices,
                   kv_last_page_lens,
                   max_seqlen_qo,
                   sm_scale=sm_scale,
                   logit_cap=logit_cap)
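
Note that aiter.mla.mla_decode_fwd is imported inside the function body, so this module can be imported on systems without the AITER package installed; the import only runs when the op is actually invoked. Note also that the underlying AITER kernel takes max_seqlen_qo positionally after the paged-KV arguments, whereas the vLLM-facing wrappers above accept it immediately after qo_indptr.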