Skip to content

vllm.attention.backends.rocm_aiter_mla

AiterMLABackend

Bases: MLACommonBackend

Source code in vllm/attention/backends/rocm_aiter_mla.py
class AiterMLABackend(MLACommonBackend):
    """MLA attention backend for ROCm that dispatches to AITER kernels.

    Every hook is a static method returning one of the AITER-specific
    classes of this module; shared MLA behaviour is inherited from
    ``MLACommonBackend``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the string that identifies this backend."""
        return "ROCM_AITER_MLA"

    @staticmethod
    def get_state_cls() -> Type["AiterMLAState"]:
        """Return the class holding graph-capture state for this backend."""
        return AiterMLAState

    @staticmethod
    def get_metadata_cls() -> Type["AiterMLAMetadata"]:
        """Return the per-batch attention metadata class."""
        return AiterMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]:
        """Return the class that builds ``AiterMLAMetadata`` for a batch."""
        return AiterMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> Type["AiterMLAImpl"]:
        """Return the attention implementation class."""
        return AiterMLAImpl

get_builder_cls staticmethod

get_builder_cls() -> Type[AiterMLAMetadataBuilder]
Source code in vllm/attention/backends/rocm_aiter_mla.py
@staticmethod
def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]:
    """Return the class that builds ``AiterMLAMetadata`` for a batch."""
    builder_cls = AiterMLAMetadataBuilder
    return builder_cls

get_impl_cls staticmethod

get_impl_cls() -> Type[AiterMLAImpl]
Source code in vllm/attention/backends/rocm_aiter_mla.py
@staticmethod
def get_impl_cls() -> Type["AiterMLAImpl"]:
    """Return the attention implementation class of this backend."""
    impl_cls = AiterMLAImpl
    return impl_cls

get_metadata_cls staticmethod

get_metadata_cls() -> Type[AiterMLAMetadata]
Source code in vllm/attention/backends/rocm_aiter_mla.py
@staticmethod
def get_metadata_cls() -> Type["AiterMLAMetadata"]:
    """Return the per-batch attention metadata class of this backend."""
    metadata_cls = AiterMLAMetadata
    return metadata_cls

get_name staticmethod

get_name() -> str
Source code in vllm/attention/backends/rocm_aiter_mla.py
@staticmethod
def get_name() -> str:
    """Return the string that identifies the AITER MLA backend."""
    backend_name = "ROCM_AITER_MLA"
    return backend_name

get_state_cls staticmethod

get_state_cls() -> Type[AiterMLAState]
Source code in vllm/attention/backends/rocm_aiter_mla.py
@staticmethod
def get_state_cls() -> Type["AiterMLAState"]:
    """Return the class holding graph-capture state for this backend."""
    state_cls = AiterMLAState
    return state_cls

AiterMLAImpl

Bases: MLACommonImpl[AiterMLAMetadata]

Source code in vllm/attention/backends/rocm_aiter_mla.py
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
    """MLA attention implementation backed by AITER kernels.

    ``_flash_attn_varlen_diff_headdims`` routes variable-length attention to
    AITER's ``flash_attn_varlen_func``; ``_forward_decode`` runs decode
    attention through ``aiter_mla_decode_fwd`` over the latent KV cache.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialize the implementation and bind AITER's varlen kernel.

        All arguments are forwarded unchanged to ``MLACommonImpl``.

        Raises:
            NotImplementedError: If any of ``alibi_slopes``,
                ``sliding_window``, ``blocksparse_params`` or
                ``logits_soft_cap`` is truthy.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "Aiter MLA does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        # Imported at construction time rather than at module import time.
        from aiter import flash_attn_varlen_func
        self.flash_attn_varlen_func = flash_attn_varlen_func

    def _flash_attn_varlen_diff_headdims(
            self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
            softmax_scale: float, return_softmax_lse: bool,
            **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
        """Run AITER's variable-length flash attention.

        Args:
            q, k, v: Query/key/value tensors, forwarded positionally.
            softmax_scale: Scale applied to the attention logits.
            return_softmax_lse: When True, ask the kernel to also return the
                softmax log-sum-exp (AITER's ``return_lse`` argument).
            **kwargs: Remaining kernel arguments (sequence offsets, max
                lengths, causal flag, ...), passed through unchanged.

        Returns:
            Whatever ``flash_attn_varlen_func`` returns: the attention
            output, or a tuple that also carries the LSE when requested.
        """
        # Both arguments are consumed by this signature and therefore never
        # reach **kwargs; they must be forwarded explicitly, otherwise the
        # kernel uses its own default scale and does not return the LSE.
        output = self.flash_attn_varlen_func(
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            return_lse=return_softmax_lse,
            **kwargs,
        )

        return output

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: AiterMLAMetadata,
    ) -> torch.Tensor:
        """Decode attention via ``aiter_mla_decode_fwd``.

        ``q_nope`` and ``q_pe`` are concatenated on the last dimension to
        form the query. A ``(B, num_heads, kv_lora_rank)`` buffer is passed
        to the kernel as its output argument and the result is projected by
        ``self._v_up_proj``. Paging tensors are read from the top-level
        ``attn_metadata``.
        """
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None
        B = q_nope.shape[0]

        q = torch.cat([q_nope, q_pe], dim=-1)
        o = torch.empty(B,
                        self.num_heads,
                        self.kv_lora_rank,
                        dtype=q.dtype,
                        device=q.device)

        # Insert a singleton dimension at axis 2 of the cache for the kernel.
        kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

        aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
                             attn_metadata.qo_indptr,
                             attn_metadata.max_query_len,
                             attn_metadata.paged_kv_indptr,
                             attn_metadata.paged_kv_indices,
                             attn_metadata.paged_kv_last_page_lens)

        return self._v_up_proj(o)

flash_attn_varlen_func instance-attribute

flash_attn_varlen_func = flash_attn_varlen_func

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[dict[str, Any]],
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    **mla_args,
) -> None
Source code in vllm/attention/backends/rocm_aiter_mla.py
def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]],
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        **mla_args) -> None:
    """Initialize the AITER MLA implementation.

    Every argument is forwarded unchanged to the common MLA initializer;
    afterwards AITER's ``flash_attn_varlen_func`` is bound on the instance.

    Raises:
        NotImplementedError: If any of ``alibi_slopes``, ``sliding_window``,
            ``blocksparse_params`` or ``logits_soft_cap`` is truthy.
    """
    super().__init__(num_heads, head_size, scale, num_kv_heads,
                     alibi_slopes, sliding_window, kv_cache_dtype,
                     blocksparse_params, logits_soft_cap, attn_type,
                     kv_sharing_target_layer_name, **mla_args)

    requested_unsupported = any((alibi_slopes, sliding_window,
                                 blocksparse_params, logits_soft_cap))
    if requested_unsupported:
        raise NotImplementedError(
            "Aiter MLA does not support one of the following: "
            "alibi_slopes, sliding_window, blocksparse_params, "
            "logits_soft_cap")

    # Imported lazily, at construction time.
    from aiter import flash_attn_varlen_func as aiter_varlen_attention
    self.flash_attn_varlen_func = aiter_varlen_attention

_flash_attn_varlen_diff_headdims

_flash_attn_varlen_diff_headdims(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    softmax_scale: float,
    return_softmax_lse: bool,
    **kwargs,
) -> Union[tuple[Tensor, ...], Tensor]
Source code in vllm/attention/backends/rocm_aiter_mla.py
def _flash_attn_varlen_diff_headdims(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
        softmax_scale: float, return_softmax_lse: bool,
        **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
    """Run AITER's variable-length flash attention.

    Args:
        q, k, v: Query/key/value tensors, forwarded positionally.
        softmax_scale: Scale applied to the attention logits.
        return_softmax_lse: When True, ask the kernel to also return the
            softmax log-sum-exp (AITER's ``return_lse`` argument).
        **kwargs: Remaining kernel arguments (sequence offsets, max lengths,
            causal flag, ...), passed through unchanged.

    Returns:
        Whatever ``self.flash_attn_varlen_func`` returns: the attention
        output, or a tuple that also carries the LSE when requested.
    """
    # Both arguments are consumed by this signature and therefore never
    # reach **kwargs; they must be forwarded explicitly, otherwise the kernel
    # uses its own default scale and does not return the LSE.
    output = self.flash_attn_varlen_func(
        q,
        k,
        v,
        softmax_scale=softmax_scale,
        return_lse=return_softmax_lse,
        **kwargs,
    )

    return output

_forward_decode

_forward_decode(
    q_nope: Tensor,
    q_pe: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    attn_metadata: AiterMLAMetadata,
) -> Tensor
Source code in vllm/attention/backends/rocm_aiter_mla.py
def _forward_decode(
    self,
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata: AiterMLAMetadata,
) -> torch.Tensor:
    """Compute decode attention with AITER's MLA decode kernel.

    The query is ``q_nope`` and ``q_pe`` joined on the last dimension. A
    ``(q_nope.shape[0], num_heads, kv_lora_rank)`` buffer is handed to
    ``aiter_mla_decode_fwd`` as its output argument, and the result is
    passed through ``self._v_up_proj``. Paging tensors are read from the
    top-level ``attn_metadata``.
    """
    assert kv_c_and_k_pe_cache.numel() > 0
    assert attn_metadata.decode_metadata is not None

    batch_dim = q_nope.shape[0]
    query = torch.cat([q_nope, q_pe], dim=-1)
    latent_out = torch.empty((batch_dim, self.num_heads, self.kv_lora_rank),
                             dtype=query.dtype,
                             device=query.device)
    # Insert a singleton dimension at axis 2 of the cache for the kernel.
    paged_kv = kv_c_and_k_pe_cache.unsqueeze(2)

    aiter_mla_decode_fwd(query, paged_kv, latent_out, self.scale,
                         attn_metadata.qo_indptr,
                         attn_metadata.max_query_len,
                         attn_metadata.paged_kv_indptr,
                         attn_metadata.paged_kv_indices,
                         attn_metadata.paged_kv_last_page_lens)
    return self._v_up_proj(latent_out)

AiterMLAMetadata dataclass

Bases: MLACommonMetadata

Source code in vllm/attention/backends/rocm_aiter_mla.py
@dataclass
class AiterMLAMetadata(MLACommonMetadata):
    """``MLACommonMetadata`` plus the paged-KV tensors used by AITER MLA.

    The prefill and decode views returned by the properties below carry the
    same AITER tensors as the full-batch metadata they were derived from.
    """

    # The following 5 tensors are for current version of AITER MLA
    block_table_bound: Optional[torch.Tensor] = None
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_lens: Optional[torch.Tensor] = None

    # This is just to make new AITER MLA API work
    # -- MTP support is not added yet.
    qo_indptr: Optional[torch.Tensor] = None

    def _attach_aiter_tensors(self, view):
        """Copy this batch's AITER tensors onto ``view`` (mutating it) and
        return a new instance of this class built from its attributes."""
        view.paged_kv_indptr = self.paged_kv_indptr
        view.paged_kv_indices = self.paged_kv_indices
        view.paged_kv_last_page_lens = self.paged_kv_last_page_lens
        view.block_table_bound = self.block_table_bound
        view.qo_indptr = self.qo_indptr
        return self.__class__(**view.__dict__)

    @property
    def prefill_metadata(self):
        """Prefill-only view with the AITER tensors attached, or None."""
        base_view = super().prefill_metadata
        self._cached_prefill_metadata = base_view
        if base_view is not None:
            # Re-wrap so the cached view is an instance of this class.
            self._cached_prefill_metadata = self._attach_aiter_tensors(
                base_view)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self):
        """Decode-only view with the AITER tensors attached, or None."""
        base_view = super().decode_metadata
        self._cached_decode_metadata = base_view
        if base_view is not None:
            # Re-wrap so the cached view is an instance of this class.
            self._cached_decode_metadata = self._attach_aiter_tensors(
                base_view)
        return self._cached_decode_metadata

    def _ops_advance_step(self, num_seqs: int, num_queries: int,
                          block_size: int, input_tokens: torch.Tensor,
                          sampled_token_ids: torch.Tensor,
                          input_positions: torch.Tensor) -> None:
        """Forward the step inputs and this metadata's tensors to
        ``ops.advance_step_flashinfer``."""
        step_inputs = dict(num_seqs=num_seqs,
                           num_queries=num_queries,
                           block_size=block_size,
                           input_tokens=input_tokens,
                           sampled_token_ids=sampled_token_ids,
                           input_positions=input_positions)
        metadata_tensors = dict(
            seq_lens=self.seq_lens_tensor,
            slot_mapping=self.slot_mapping,
            block_tables=self.block_tables,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_last_page_lens=self.paged_kv_last_page_lens,
            block_table_bound=self.block_table_bound)
        ops.advance_step_flashinfer(**step_inputs, **metadata_tensors)

block_table_bound class-attribute instance-attribute

block_table_bound: Optional[Tensor] = None

decode_metadata property

decode_metadata

paged_kv_indices class-attribute instance-attribute

paged_kv_indices: Optional[Tensor] = None

paged_kv_indptr class-attribute instance-attribute

paged_kv_indptr: Optional[Tensor] = None

paged_kv_last_page_lens class-attribute instance-attribute

paged_kv_last_page_lens: Optional[Tensor] = None

prefill_metadata property

prefill_metadata

qo_indptr class-attribute instance-attribute

qo_indptr: Optional[Tensor] = None

__init__

__init__(
    num_prefills: int,
    num_prefill_tokens: int,
    num_decode_tokens: int,
    slot_mapping: Tensor,
    multi_modal_placeholder_index_maps: Optional[
        Dict[str, IndexMap]
    ],
    enable_kv_scales_calculation: bool,
    use_cuda_graph: bool,
    seq_lens: Optional[List[int]],
    seq_lens_tensor: Optional[Tensor],
    max_prefill_seq_len: int,
    max_decode_seq_len: int,
    context_lens_tensor: Optional[Tensor],
    block_tables: Optional[Tensor],
    max_query_len: Optional[int] = None,
    max_decode_query_len: Optional[int] = None,
    query_start_loc: Optional[Tensor] = None,
    seq_start_loc: Optional[Tensor] = None,
    _cached_prefill_metadata: Optional[Any] = None,
    _cached_decode_metadata: Optional[Any] = None,
    head_dim: Optional[int] = None,
    is_profile_run: bool = False,
    context_chunk_cu_seq_lens: Optional[Tensor] = None,
    context_chunk_starts: Optional[Tensor] = None,
    context_chunk_seq_tot: Optional[List[int]] = None,
    context_chunk_max_seq_lens: Optional[List[int]] = None,
    context_chunk_workspace: Optional[Tensor] = None,
    block_table_bound: Optional[Tensor] = None,
    paged_kv_indptr: Optional[Tensor] = None,
    paged_kv_indices: Optional[Tensor] = None,
    paged_kv_last_page_lens: Optional[Tensor] = None,
    qo_indptr: Optional[Tensor] = None,
) -> None

_ops_advance_step

_ops_advance_step(
    num_seqs: int,
    num_queries: int,
    block_size: int,
    input_tokens: Tensor,
    sampled_token_ids: Tensor,
    input_positions: Tensor,
) -> None
Source code in vllm/attention/backends/rocm_aiter_mla.py
def _ops_advance_step(self, num_seqs: int, num_queries: int,
                      block_size: int, input_tokens: torch.Tensor,
                      sampled_token_ids: torch.Tensor,
                      input_positions: torch.Tensor) -> None:
    """Forward the step inputs together with this metadata's sequence and
    paged-KV tensors to ``ops.advance_step_flashinfer``."""
    step_inputs = dict(num_seqs=num_seqs,
                       num_queries=num_queries,
                       block_size=block_size,
                       input_tokens=input_tokens,
                       sampled_token_ids=sampled_token_ids,
                       input_positions=input_positions)
    metadata_tensors = dict(
        seq_lens=self.seq_lens_tensor,
        slot_mapping=self.slot_mapping,
        block_tables=self.block_tables,
        paged_kv_indices=self.paged_kv_indices,
        paged_kv_indptr=self.paged_kv_indptr,
        paged_kv_last_page_lens=self.paged_kv_last_page_lens,
        block_table_bound=self.block_table_bound)
    ops.advance_step_flashinfer(**step_inputs, **metadata_tensors)

AiterMLAMetadataBuilder

Bases: MLACommonMetadataBuilder[AiterMLAMetadata]

Source code in vllm/attention/backends/rocm_aiter_mla.py
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
    """Builds ``AiterMLAMetadata`` for a batch.

    In addition to the common MLA builder state, it accumulates the
    ``paged_kv_indices``, ``paged_kv_indptr``, ``paged_kv_last_page_lens``
    and ``qo_indptr`` lists, which ``build`` converts into int tensors.
    """
    # NOTE(review): not referenced in this class; presumably consumed by the
    # shared builder code -- confirm before changing.
    BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        """Create the builder; AITER MLA only supports a block size of 1.

        Raises:
            AssertionError: If ``self.block_size`` is not 1.
        """
        super().__init__(input_builder)
        assert self.block_size == 1, "AITER MLA requires only block size 1."

    def prepare(self):
        """Reset per-batch state, including the AITER paged-KV lists."""
        super().prepare()
        self.paged_kv_indices: list[int] = []
        # Offset lists start at 0; each sequence appends its end offset.
        self.paged_kv_indptr: list[int] = [0]
        self.paged_kv_last_page_lens: list[int] = []
        # Running sum of full block-table lengths; build() pads
        # paged_kv_indices with zeros up to this size.
        self.total_blocks = 0
        self.qo_indptr: list[int] = [0]

    def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
                       prefix_cache_hit: bool):
        """Add a sequence group to the metadata.

        For each sequence in ``inter_data`` this appends/updates:
        1. context length,
        2. prefill or decode counters and sequence lengths,
        3. block table,
        4. slot mapping,
        5. (non-profile runs only) the AITER paged-KV lists, via
           ``_update_paged_kv_tensors``.

        On a profile run (empty block tables) the method returns right after
        computing the slot mapping of the first sequence, so later sequences
        of the group are not visited.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)
            # NOTE(review): `return`, not `continue`: a profile run stops at
            # the first sequence of the group.
            if is_profile_run:
                return

            # Update paged_kv_* tensors only for non-profile run
            # (uses the full block table, not the sliding-window slice above)
            block_table = block_tables[seq_id]
            self._update_paged_kv_tensors(block_table, seq_len)

    def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
        """Append one sequence's entries to the AITER paged-KV lists.

        Args:
            block_table: Full block table of the sequence.
            seq_len: Sequence length in tokens.
        """
        # Get the number of valid blocks based on sequence length.
        # If seq_len = 16, block_size = 16,
        # block_table_bound is 1 with 1 valid block.
        # If seq_len = 15, block_size = 16,
        # block_table_bound is 0 + 1 with 1 valid block.
        self.total_blocks += len(block_table)
        block_table_bound = seq_len // self.block_size + 1 \
            if seq_len % self.block_size != 0 \
            else seq_len // self.block_size
        self.paged_kv_indices.extend(block_table[:block_table_bound])
        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                    block_table_bound)
        # Exactly one query-offset step per sequence.
        self.qo_indptr.append(self.qo_indptr[-1] + 1)

        # Tokens in the final page; a page filled exactly counts as full.
        last_page_len = seq_len % self.block_size
        if last_page_len == 0:
            last_page_len = self.block_size
        self.paged_kv_last_page_lens.append(last_page_len)

    def build(self, seq_lens: list[int], query_lens: list[int],
              cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata:
        """Build the batch metadata and attach AITER's paged-KV tensors.

        Args:
            seq_lens, query_lens, batch_size: Forwarded to the base builder.
            cuda_graph_pad_size: -1 when no CUDA graph is used; otherwise the
                number of padding entries appended to the per-sequence lists.

        Returns:
            The base builder's metadata with ``paged_kv_indptr``,
            ``paged_kv_indices``, ``paged_kv_last_page_lens``,
            ``block_table_bound`` and ``qo_indptr`` set to ``torch.int``
            tensors on ``self.runner.device`` (or all None).
        """
        metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size,
                                 batch_size)
        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        if use_captured_graph:
            # Padding entries repeat the final offsets (zero pages and zero
            # queries each) and report a last-page length of 0.
            last_paged_kv_indptr = self.paged_kv_indptr[-1]
            self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                        cuda_graph_pad_size)
            self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
            last_qo_indptr = self.qo_indptr[-1]
            self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)

        # For current version of AITER MLA
        # NOTE: prepare() seeds paged_kv_indptr with [0], so after prepare()
        # this condition is always true.
        if len(self.paged_kv_indptr) > 0:
            # extend to the maximum number of blocks as returned by the
            # scheduler
            self.paged_kv_indices.extend(
                [0] * (self.total_blocks - len(self.paged_kv_indices)))
            paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                   device=device,
                                                   dtype=torch.int)
            paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
                                                  device=device,
                                                  dtype=torch.int)
            paged_kv_last_page_lens_tensor = torch.tensor(
                self.paged_kv_last_page_lens, device=device, dtype=torch.int)
            # One zero per (possibly padded) sequence entry.
            block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
                                                   1,
                                                   device=device,
                                                   dtype=torch.int)

            qo_indptr = torch.tensor(self.qo_indptr,
                                     device=device,
                                     dtype=torch.int)
        else:
            paged_kv_indices_tensor = None
            paged_kv_indptr_tensor = None
            paged_kv_last_page_lens_tensor = None
            block_table_bound_tensor = None
            qo_indptr = None

        metadata.paged_kv_indptr = paged_kv_indptr_tensor
        metadata.paged_kv_indices = paged_kv_indices_tensor
        metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
        metadata.block_table_bound = block_table_bound_tensor
        metadata.qo_indptr = qo_indptr

        return metadata

BLOCK_TABLE_EXTENDER class-attribute instance-attribute

BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]

__init__

__init__(input_builder: ModelInputForGPUBuilder)
Source code in vllm/attention/backends/rocm_aiter_mla.py
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
    """Create the builder on top of the common MLA metadata builder.

    Raises:
        AssertionError: If ``self.block_size`` is not 1.
    """
    super().__init__(input_builder)
    # Each KV-cache block must hold exactly one token for this backend.
    assert self.block_size == 1, "AITER MLA requires only block size 1."

_add_seq_group

_add_seq_group(
    inter_data,
    chunked_prefill_enabled: bool,
    prefix_cache_hit: bool,
)

Add a sequence group to the metadata. Specifically, update/append: (1) context length, (2) block table, (3) slot mapping.

Source code in vllm/attention/backends/rocm_aiter_mla.py
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
                   prefix_cache_hit: bool):
    """Add a sequence group to the metadata.

    For each sequence in ``inter_data`` this appends/updates:
    1. context length,
    2. prefill or decode counters and sequence lengths,
    3. block table,
    4. slot mapping,
    5. (non-profile runs only) the AITER paged-KV lists, via
       ``_update_paged_kv_tensors``.

    On a profile run (empty block tables) the method returns right after
    computing the slot mapping of the first sequence, so later sequences of
    the group are not visited.
    """
    is_prompt = inter_data.is_prompt
    block_tables = inter_data.block_tables

    for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
         curr_sliding_window_block) in zip(
             inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
             inter_data.orig_seq_lens, inter_data.seq_lens,
             inter_data.query_lens, inter_data.context_lens,
             inter_data.curr_sliding_window_blocks):
        self.context_lens.append(context_len)
        if is_prompt:
            self.num_prefills += 1
            self.num_prefill_tokens += token_len
            self.prefill_seq_lens.append(seq_len)
        else:
            self.num_decode_tokens += query_len
            self.curr_seq_lens.append(curr_seq_len)

        # Compute block table.
        # TODO(sang): Combine chunked prefill and prefix caching by
        # only allowing multiple of block_size chunk size.
        # NOTE: This only works for oooooooxxx style attention.
        block_table = []
        if prefix_cache_hit:
            # NOTE(woosuk): For flash-attn, the block table should
            # include the entries for the incoming prefill tokens.
            block_table = block_tables[seq_id]
        elif ((chunked_prefill_enabled or not is_prompt)
              and block_tables is not None):
            if curr_sliding_window_block == 0:
                block_table = block_tables[seq_id]
            else:
                block_table = block_tables[seq_id][
                    -curr_sliding_window_block:]
        self.block_tables.append(block_table)

        # Compute slot mapping.
        is_profile_run = is_block_tables_empty(block_tables)
        start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                   context_len,
                                                   self.sliding_window)
        compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                             seq_len, context_len, start_idx,
                             self.block_size, inter_data.block_tables)
        # NOTE(review): `return`, not `continue`: a profile run stops at the
        # first sequence of the group.
        if is_profile_run:
            return

        # Update paged_kv_* tensors only for non-profile run
        # (uses the full block table, not the sliding-window slice above)
        block_table = block_tables[seq_id]
        self._update_paged_kv_tensors(block_table, seq_len)

_update_paged_kv_tensors

_update_paged_kv_tensors(
    block_table: list[int], seq_len: int
)
Source code in vllm/attention/backends/rocm_aiter_mla.py
def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
    """Append one sequence's entries to the AITER paged-KV lists.

    The number of valid blocks is ``ceil(seq_len / block_size)``: for
    ``block_size=16``, a length of 16 uses one block and a length of 15 also
    uses one (partially filled) block.

    Args:
        block_table: Full block table of the sequence; its whole length is
            added to ``self.total_blocks``.
        seq_len: Sequence length in tokens.
    """
    full_pages, remainder = divmod(seq_len, self.block_size)
    # A partially filled trailing page still occupies a block.
    num_valid_blocks = full_pages + (1 if remainder else 0)

    self.total_blocks += len(block_table)
    self.paged_kv_indices.extend(block_table[:num_valid_blocks])
    self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + num_valid_blocks)
    # Exactly one query-offset step per sequence.
    self.qo_indptr.append(self.qo_indptr[-1] + 1)
    # Tokens in the final page; an exactly filled page counts as full.
    self.paged_kv_last_page_lens.append(remainder or self.block_size)

build

build(
    seq_lens: list[int],
    query_lens: list[int],
    cuda_graph_pad_size: int,
    batch_size: int,
) -> AiterMLAMetadata
Source code in vllm/attention/backends/rocm_aiter_mla.py
def build(self, seq_lens: list[int], query_lens: list[int],
          cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata:
    """Build the batch metadata and attach AITER's paged-KV tensors.

    Args:
        seq_lens, query_lens, batch_size: Forwarded to the base builder.
        cuda_graph_pad_size: -1 when no CUDA graph is used; otherwise the
            number of padding entries appended to the per-sequence lists.

    Returns:
        The base builder's metadata with ``paged_kv_indptr``,
        ``paged_kv_indices``, ``paged_kv_last_page_lens``,
        ``block_table_bound`` and ``qo_indptr`` set to ``torch.int`` tensors
        on ``self.runner.device`` (or all None).
    """
    metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size,
                             batch_size)
    target_device = self.runner.device

    def to_int_tensor(values):
        return torch.tensor(values, device=target_device, dtype=torch.int)

    if cuda_graph_pad_size != -1:
        # Padding entries repeat the final offsets (zero pages and zero
        # queries each) and report a last-page length of 0.
        self.paged_kv_indptr.extend([self.paged_kv_indptr[-1]] *
                                    cuda_graph_pad_size)
        self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
        self.qo_indptr.extend([self.qo_indptr[-1]] * cuda_graph_pad_size)

    if self.paged_kv_indptr:
        # Zero-pad the page indices up to the summed block-table length.
        missing = self.total_blocks - len(self.paged_kv_indices)
        self.paged_kv_indices.extend([0] * missing)
        indices_t = to_int_tensor(self.paged_kv_indices)
        indptr_t = to_int_tensor(self.paged_kv_indptr)
        last_page_lens_t = to_int_tensor(self.paged_kv_last_page_lens)
        bound_t = torch.zeros(len(self.paged_kv_indptr) - 1,
                              device=target_device,
                              dtype=torch.int)
        qo_indptr_t = to_int_tensor(self.qo_indptr)
    else:
        indices_t = indptr_t = last_page_lens_t = None
        bound_t = qo_indptr_t = None

    metadata.paged_kv_indptr = indptr_t
    metadata.paged_kv_indices = indices_t
    metadata.paged_kv_last_page_lens = last_page_lens_t
    metadata.block_table_bound = bound_t
    metadata.qo_indptr = qo_indptr_t
    return metadata

prepare

prepare()
Source code in vllm/attention/backends/rocm_aiter_mla.py
def prepare(self):
    """Reset builder state for a new batch, including AITER's lists."""
    super().prepare()
    # Offset lists start with a leading 0; each sequence appends its end.
    self.paged_kv_indptr: list[int] = [0]
    self.qo_indptr: list[int] = [0]
    self.paged_kv_indices: list[int] = []
    self.paged_kv_last_page_lens: list[int] = []
    # Running sum of full block-table lengths.
    self.total_blocks = 0

AiterMLAState

Bases: MLACommonState[AiterMLAMetadata]

Source code in vllm/attention/backends/rocm_aiter_mla.py
class AiterMLAState(MLACommonState[AiterMLAMetadata]):
    """MLA graph-capture state that also owns AITER's paged-KV buffers.

    ``graph_capture`` allocates persistent buffers via
    ``get_aiter_mla_metadata``; the other hooks slice them per batch, expose
    them as graph inputs, and refresh them from new metadata before replay.
    """

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Hold AITER MLA's persistent buffers for the capture's duration.

        The buffers are stored on ``self`` before entering the base class's
        capture context and removed on exit. Removal happens in ``finally``
        so that a failing capture does not leave them attached to this
        object.
        """
        kv_indices, kv_indptr, last_page_lens, qo_indptr = \
            get_aiter_mla_metadata(
                max_batch_size=max_batch_size,
                block_size=self.runner.block_size,
                max_block_per_batch=self.runner.get_max_block_per_batch(),
                device=self.runner.device)
        self._paged_kv_indices_tensor = kv_indices
        self._paged_kv_indptr_tensor = kv_indptr
        self._paged_kv_last_page_lens_tensor = last_page_lens
        self._qo_indptr_tensor = qo_indptr

        try:
            with super().graph_capture(max_batch_size):
                yield
        finally:
            del self._paged_kv_indices_tensor
            del self._paged_kv_indptr_tensor
            del self._paged_kv_last_page_lens_tensor
            del self._qo_indptr_tensor

    def graph_capture_get_metadata_for_batch(
            self,
            batch_size: int,
            is_encoder_decoder_model: bool = False) -> AiterMLAMetadata:
        """Return capture metadata whose AITER fields view the persistent
        buffers, sliced to ``batch_size`` sequences."""
        metadata = super().graph_capture_get_metadata_for_batch(
            batch_size, is_encoder_decoder_model)

        # indptr buffers hold one more entry than there are sequences.
        paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1]
        # Not sliced: prepare_graph_input_buffers fills only its prefix.
        paged_kv_indices = self._paged_kv_indices_tensor
        paged_kv_last_page_lens = (
            self._paged_kv_last_page_lens_tensor[:batch_size])
        qo_indptr = self._qo_indptr_tensor[:batch_size + 1]

        metadata.paged_kv_indptr = paged_kv_indptr
        metadata.paged_kv_indices = paged_kv_indices
        metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
        metadata.qo_indptr = qo_indptr

        return metadata

    def get_graph_input_buffers(self,
                                attn_metadata: AiterMLAMetadata,
                                is_encoder_decoder_model: bool = False):
        """Extend the base graph input buffers with AITER's tensors.

        Raises:
            AttributeError: If ``attn_metadata`` has no decode view.
        """
        input_buffers = super().get_graph_input_buffers(
            attn_metadata, is_encoder_decoder_model)
        # decode_metadata builds a new object on every access; read it once.
        decode_meta = attn_metadata.decode_metadata
        input_buffers['paged_kv_indptr'] = decode_meta.paged_kv_indptr
        input_buffers["paged_kv_indices"] = decode_meta.paged_kv_indices
        input_buffers["paged_kv_last_page_lens"] = \
            decode_meta.paged_kv_last_page_lens
        input_buffers['qo_indptr'] = attn_metadata.qo_indptr

        return input_buffers

    def prepare_graph_input_buffers(self,
                                    input_buffers,
                                    attn_metadata: AiterMLAMetadata,
                                    is_encoder_decoder_model: bool = False):
        """Copy the batch's AITER tensors into the captured graph buffers.

        ``paged_kv_indices`` is copied into a prefix of the persistent
        buffer sized by the batch's index count; the other tensors are
        copied whole. Copies are issued with ``non_blocking=True``.
        """
        super().prepare_graph_input_buffers(input_buffers, attn_metadata,
                                            is_encoder_decoder_model)

        # decode_metadata builds a new object on every access; read it once.
        decode_meta = attn_metadata.decode_metadata
        num_total_blocks = decode_meta.paged_kv_indices.shape[0]
        input_buffers["paged_kv_indptr"].copy_(decode_meta.paged_kv_indptr,
                                               non_blocking=True)
        input_buffers["paged_kv_indices"][:num_total_blocks].copy_(
            decode_meta.paged_kv_indices, non_blocking=True)
        input_buffers["paged_kv_last_page_lens"].copy_(
            decode_meta.paged_kv_last_page_lens, non_blocking=True)
        input_buffers["qo_indptr"].copy_(decode_meta.qo_indptr,
                                         non_blocking=True)

get_graph_input_buffers

get_graph_input_buffers(
    attn_metadata: AiterMLAMetadata,
    is_encoder_decoder_model: bool = False,
)
Source code in vllm/attention/backends/rocm_aiter_mla.py
def get_graph_input_buffers(self,
                            attn_metadata: AiterMLAMetadata,
                            is_encoder_decoder_model: bool = False):
    """Return the parent's graph input buffers extended with AITER MLA ones.

    The three paged-KV tensors are taken from
    ``attn_metadata.decode_metadata``; ``qo_indptr`` is taken from
    ``attn_metadata`` directly.

    NOTE(review): ``prepare_graph_input_buffers`` refreshes ``qo_indptr``
    from ``attn_metadata.decode_metadata.qo_indptr`` instead; presumably both
    name the same tensor -- confirm against
    ``AiterMLAMetadata.decode_metadata``.
    """
    buffers = super().get_graph_input_buffers(attn_metadata,
                                              is_encoder_decoder_model)

    aiter_entries = (
        ("paged_kv_indptr", attn_metadata.decode_metadata.paged_kv_indptr),
        ("paged_kv_indices",
         attn_metadata.decode_metadata.paged_kv_indices),
        ("paged_kv_last_page_lens",
         attn_metadata.decode_metadata.paged_kv_last_page_lens),
        ("qo_indptr", attn_metadata.qo_indptr),
    )
    for key, tensor in aiter_entries:
        buffers[key] = tensor

    return buffers

graph_capture

graph_capture(max_batch_size: int)
Source code in vllm/attention/backends/rocm_aiter_mla.py
@contextmanager
def graph_capture(self, max_batch_size: int):
    """Hold AITER MLA decode buffers on ``self`` for the duration of capture.

    Obtains the paged-KV indices, paged-KV indptr, last-page-length and
    query indptr buffers from ``get_aiter_mla_metadata`` for
    ``max_batch_size`` and stores them on ``self``, where
    ``graph_capture_get_metadata_for_batch`` slices them per batch. The
    parent class's ``graph_capture`` context is entered around the
    ``yield``. The four attributes are deleted when the context exits,
    including when it exits via an exception.

    Args:
        max_batch_size: Batch size passed to ``get_aiter_mla_metadata`` and
            to the parent ``graph_capture``.

    Yields:
        None.
    """
    kv_indices, kv_indptr, last_page_lens, qo_indptr = \
        get_aiter_mla_metadata(
            max_batch_size=max_batch_size,
            block_size=self.runner.block_size,
            max_block_per_batch=self.runner.get_max_block_per_batch(),
            device=self.runner.device)
    self._paged_kv_indices_tensor = kv_indices
    self._paged_kv_indptr_tensor = kv_indptr
    self._paged_kv_last_page_lens_tensor = last_page_lens
    self._qo_indptr_tensor = qo_indptr

    # With @contextmanager, an exception raised in the caller's ``with`` body
    # is re-raised at ``yield``. Cleanup placed after the ``with`` would be
    # skipped, and the buffers would stay referenced by ``self``. ``finally``
    # guarantees they are released on both the normal and error paths.
    try:
        with super().graph_capture(max_batch_size):
            yield
    finally:
        del self._paged_kv_indices_tensor
        del self._paged_kv_indptr_tensor
        del self._paged_kv_last_page_lens_tensor
        del self._qo_indptr_tensor

graph_capture_get_metadata_for_batch

graph_capture_get_metadata_for_batch(
    batch_size: int, is_encoder_decoder_model: bool = False
) -> AiterMLAMetadata
Source code in vllm/attention/backends/rocm_aiter_mla.py
def graph_capture_get_metadata_for_batch(
        self,
        batch_size: int,
        is_encoder_decoder_model: bool = False) -> AiterMLAMetadata:
    """Build capture-time metadata for ``batch_size`` from shared buffers.

    Starts from the parent's metadata and points its AITER MLA fields at the
    buffers that ``graph_capture`` stored on ``self``:

    * ``paged_kv_indptr`` and ``qo_indptr``: first ``batch_size + 1``
      entries;
    * ``paged_kv_last_page_lens``: first ``batch_size`` entries;
    * ``paged_kv_indices``: the whole buffer, unsliced.
    """
    metadata = super().graph_capture_get_metadata_for_batch(
        batch_size, is_encoder_decoder_model)

    # indptr-style arrays get batch_size + 1 entries (presumably CSR offsets
    # with a leading zero -- confirm against the AITER kernel contract).
    num_offsets = batch_size + 1
    metadata.paged_kv_indptr = self._paged_kv_indptr_tensor[:num_offsets]
    metadata.paged_kv_indices = self._paged_kv_indices_tensor
    metadata.paged_kv_last_page_lens = (
        self._paged_kv_last_page_lens_tensor[:batch_size])
    metadata.qo_indptr = self._qo_indptr_tensor[:num_offsets]
    return metadata

prepare_graph_input_buffers

prepare_graph_input_buffers(
    input_buffers,
    attn_metadata: AiterMLAMetadata,
    is_encoder_decoder_model: bool = False,
)
Source code in vllm/attention/backends/rocm_aiter_mla.py
def prepare_graph_input_buffers(self,
                                input_buffers,
                                attn_metadata: AiterMLAMetadata,
                                is_encoder_decoder_model: bool = False):
    """Refresh captured-graph input buffers from ``attn_metadata``.

    After the parent class refreshes its own entries, the decode tensors
    ``paged_kv_indptr``, ``paged_kv_indices``, ``paged_kv_last_page_lens``
    and ``qo_indptr`` of ``attn_metadata.decode_metadata`` are copied into
    the same-named entries of ``input_buffers`` with ``non_blocking=True``.
    For ``paged_kv_indices`` only the leading ``shape[0]`` slots of the
    destination are overwritten.
    """
    super().prepare_graph_input_buffers(input_buffers, attn_metadata,
                                        is_encoder_decoder_model)

    # The indices buffer may be longer than this batch's index list, so only
    # its leading slice is used as the copy target.
    live = attn_metadata.decode_metadata.paged_kv_indices.shape[0]
    destinations = (
        ("paged_kv_indptr", input_buffers["paged_kv_indptr"]),
        ("paged_kv_indices", input_buffers["paged_kv_indices"][:live]),
        ("paged_kv_last_page_lens",
         input_buffers["paged_kv_last_page_lens"]),
        ("qo_indptr", input_buffers["qo_indptr"]),
    )
    for attr, dst in destinations:
        dst.copy_(getattr(attn_metadata.decode_metadata, attr),
                  non_blocking=True)

is_aiter_mla_enabled

is_aiter_mla_enabled() -> bool
Source code in vllm/attention/backends/rocm_aiter_mla.py
def is_aiter_mla_enabled() -> bool:
    """Report whether the ROCm AITER MLA path is switched on.

    Requires both ``envs.VLLM_ROCM_USE_AITER`` and
    ``envs.VLLM_ROCM_USE_AITER_MLA`` to be truthy. The MLA flag is only read
    when the general AITER flag is truthy.
    """
    aiter_on = envs.VLLM_ROCM_USE_AITER
    return aiter_on and envs.VLLM_ROCM_USE_AITER_MLA