Skip to content

vllm.attention.backends.triton_mla

TritonMLABackend

Bases: MLACommonBackend

Source code in vllm/attention/backends/triton_mla.py
class TritonMLABackend(MLACommonBackend):
    """Attention backend descriptor that pairs the name ``TRITON_MLA``
    with the ``TritonMLAImpl`` implementation class."""

    @staticmethod
    def get_impl_cls() -> Type["TritonMLAImpl"]:
        """Return the attention implementation class for this backend."""
        return TritonMLAImpl

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "TRITON_MLA"

get_impl_cls staticmethod

get_impl_cls() -> Type[TritonMLAImpl]
Source code in vllm/attention/backends/triton_mla.py
@staticmethod
def get_impl_cls() -> Type["TritonMLAImpl"]:
    """Return the attention implementation class used by this backend."""
    return TritonMLAImpl

get_name staticmethod

get_name() -> str
Source code in vllm/attention/backends/triton_mla.py
@staticmethod
def get_name() -> str:
    """Return the identifier string of this backend."""
    return "TRITON_MLA"

TritonMLAImpl

Bases: MLACommonImpl[MLACommonMetadata]

Source code in vllm/attention/backends/triton_mla.py
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA attention implementation for the ``TRITON_MLA`` backend.

    Construction fails with ``NotImplementedError`` when a truthy value
    is given for ALiBi slopes, sliding window, block-sparse parameters
    or logits soft cap, when the attention type is not
    ``AttentionType.DECODER``, or when the KV cache dtype is quantized.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the common MLA state, then validate the config.

        Every argument is forwarded unchanged to the parent
        constructor; ``mla_args`` carries the MLA-specific keyword
        arguments.

        Raises:
            NotImplementedError: If any of ``alibi_slopes``,
                ``sliding_window``, ``blocksparse_params`` or
                ``logits_soft_cap`` is truthy, if ``attn_type`` is not
                ``AttentionType.DECODER``, or if
                ``is_quantized_kv_cache`` reports the KV cache dtype
                as quantized.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        # A truthy value for any of these options requests a feature
        # this implementation does not provide.
        if (alibi_slopes or sliding_window or blocksparse_params
                or logits_soft_cap):
            raise NotImplementedError(
                "TritonMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        # Only decoder attention is handled.
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and encoder/decoder cross-attention "
                "are not implemented for TritonMLAImpl")

        # Checked on the attribute, which is available once the parent
        # constructor has run.
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "TritonMLA with FP8 KV cache not yet supported")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ) -> torch.Tensor:
        """Run decode attention over the paged KV cache.

        Args:
            q_nope: Query part concatenated first along the last dim;
                its leading dim is used as the batch size.
            q_pe: Query part concatenated after ``q_nope`` along the
                last dim.
            kv_c_and_k_pe_cache: Non-empty cache tensor. Dim 1 is used
                as the page size, and the first ``kv_lora_rank``
                channels of the last dim are passed on as a separate
                view.
            attn_metadata: Must carry a non-``None``
                ``decode_metadata`` providing ``block_tables`` and
                ``seq_lens_tensor``.

        Returns:
            The result of ``self._v_up_proj`` applied to the attention
            output buffer.
        """
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None

        batch_size = q_nope.shape[0]
        query = torch.cat([q_nope, q_pe], dim=-1)

        # Output buffer, passed to decode_attention_fwd below.
        out_shape = (batch_size, self.num_heads, self.kv_lora_rank)
        output = torch.zeros(out_shape,
                             dtype=query.dtype,
                             device=query.device)

        num_kv_splits = 4  # TODO: heuristic

        # TODO(lucas) Allocate ahead of time
        # NOTE(lucas) the reason for the +1 on the last dim is unknown;
        # it mirrors what sglang does.
        logits_shape = (batch_size, self.num_heads, num_kv_splits,
                        self.kv_lora_rank + 1)
        attn_logits = torch.empty(logits_shape,
                                  dtype=torch.float32,
                                  device=query.device)

        # Add a head dim of 1
        paged_cache = kv_c_and_k_pe_cache.unsqueeze(2)
        latent_cache = paged_cache[..., :self.kv_lora_rank]
        page_size = paged_cache.size(1)

        # Run MQA
        decode_attention_fwd(query, paged_cache, latent_cache, output,
                             decode_meta.block_tables,
                             decode_meta.seq_lens_tensor, attn_logits,
                             num_kv_splits, self.scale, page_size)

        return self._v_up_proj(output)

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[List[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[Dict[str, Any]],
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    **mla_args,
) -> None
Source code in vllm/attention/backends/triton_mla.py
def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]],
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        **mla_args) -> None:
    """Initialise the common MLA state, then validate the config.

    Every argument is forwarded unchanged to the parent constructor;
    ``mla_args`` carries the MLA-specific keyword arguments.

    Raises:
        NotImplementedError: If any of ``alibi_slopes``,
            ``sliding_window``, ``blocksparse_params`` or
            ``logits_soft_cap`` is truthy, if ``attn_type`` is not
            ``AttentionType.DECODER``, or if ``is_quantized_kv_cache``
            reports the KV cache dtype as quantized.
    """
    super().__init__(num_heads, head_size, scale, num_kv_heads,
                     alibi_slopes, sliding_window, kv_cache_dtype,
                     blocksparse_params, logits_soft_cap, attn_type,
                     kv_sharing_target_layer_name, **mla_args)

    # A truthy value for any of these options requests a feature this
    # implementation does not provide.
    if (alibi_slopes or sliding_window or blocksparse_params
            or logits_soft_cap):
        raise NotImplementedError(
            "TritonMLAImpl does not support one of the following: "
            "alibi_slopes, sliding_window, blocksparse_params, "
            "logits_soft_cap")

    # Only decoder attention is handled.
    if attn_type != AttentionType.DECODER:
        raise NotImplementedError(
            "Encoder self-attention and encoder/decoder cross-attention "
            "are not implemented for TritonMLAImpl")

    # Checked on the attribute, which is available once the parent
    # constructor has run.
    if is_quantized_kv_cache(self.kv_cache_dtype):
        raise NotImplementedError(
            "TritonMLA with FP8 KV cache not yet supported")

_forward_decode

_forward_decode(
    q_nope: Tensor,
    q_pe: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    attn_metadata: MLACommonMetadata,
) -> Tensor
Source code in vllm/attention/backends/triton_mla.py
def _forward_decode(
    self,
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
    """Run decode attention over the paged KV cache.

    Args:
        q_nope: Query part concatenated first along the last dim; its
            leading dim is used as the batch size.
        q_pe: Query part concatenated after ``q_nope`` along the last
            dim.
        kv_c_and_k_pe_cache: Non-empty cache tensor. Dim 1 is used as
            the page size, and the first ``kv_lora_rank`` channels of
            the last dim are passed on as a separate view.
        attn_metadata: Must carry a non-``None`` ``decode_metadata``
            providing ``block_tables`` and ``seq_lens_tensor``.

    Returns:
        The result of ``self._v_up_proj`` applied to the attention
        output buffer.
    """
    assert kv_c_and_k_pe_cache.numel() > 0

    decode_meta = attn_metadata.decode_metadata
    assert decode_meta is not None

    batch_size = q_nope.shape[0]
    query = torch.cat([q_nope, q_pe], dim=-1)

    # Output buffer, passed to decode_attention_fwd below.
    out_shape = (batch_size, self.num_heads, self.kv_lora_rank)
    output = torch.zeros(out_shape, dtype=query.dtype, device=query.device)

    num_kv_splits = 4  # TODO: heuristic

    # TODO(lucas) Allocate ahead of time
    # NOTE(lucas) the reason for the +1 on the last dim is unknown; it
    # mirrors what sglang does.
    logits_shape = (batch_size, self.num_heads, num_kv_splits,
                    self.kv_lora_rank + 1)
    attn_logits = torch.empty(logits_shape,
                              dtype=torch.float32,
                              device=query.device)

    # Add a head dim of 1
    paged_cache = kv_c_and_k_pe_cache.unsqueeze(2)
    latent_cache = paged_cache[..., :self.kv_lora_rank]
    page_size = paged_cache.size(1)

    # Run MQA
    decode_attention_fwd(query, paged_cache, latent_cache, output,
                         decode_meta.block_tables,
                         decode_meta.seq_lens_tensor, attn_logits,
                         num_kv_splits, self.scale, page_size)

    return self._v_up_proj(output)