vllm.attention.backends.cpu_mla

CPUMLABackend

Bases: AttentionBackend

Source code in vllm/attention/backends/cpu_mla.py
class CPUMLABackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "CPU_MLA"

    @staticmethod
    def get_metadata_cls() -> Type["CPUMLAMetadata"]:
        return CPUMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
        return CPUMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["MLACommonState"]:
        return MLACommonState

    @staticmethod
    def get_impl_cls() -> Type["CPUMLAImpl"]:
        return CPUMLAImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        ops.copy_blocks_mla(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [576]
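
The static methods above act as a small registry: each returns the class that implements one piece of the backend (metadata, metadata builder, state, and the attention implementation itself). A minimal, illustrative sketch of how they resolve; the calls mirror the listing above and nothing here is vLLM-specific wiring:

backend_name = CPUMLABackend.get_name()          # "CPU_MLA"
metadata_cls = CPUMLABackend.get_metadata_cls()  # CPUMLAMetadata
builder_cls  = CPUMLABackend.get_builder_cls()   # CPUMLAMetadataBuilder
state_cls    = CPUMLABackend.get_state_cls()     # MLACommonState
impl_cls     = CPUMLABackend.get_impl_cls()      # CPUMLAImpl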

copy_blocks staticmethod

copy_blocks(
    kv_caches: List[Tensor], src_to_dists: Tensor
) -> None
Source code in vllm/attention/backends/cpu_mla.py
@staticmethod
def copy_blocks(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
) -> None:
    ops.copy_blocks_mla(kv_caches, src_to_dists)

get_builder_cls staticmethod

get_builder_cls() -> Type[CPUMLAMetadataBuilder]
Source code in vllm/attention/backends/cpu_mla.py
@staticmethod
def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
    return CPUMLAMetadataBuilder

get_impl_cls staticmethod

get_impl_cls() -> Type[CPUMLAImpl]
Source code in vllm/attention/backends/cpu_mla.py
@staticmethod
def get_impl_cls() -> Type["CPUMLAImpl"]:
    return CPUMLAImpl

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> Tuple[int, ...]
Source code in vllm/attention/backends/cpu_mla.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,  # assumed to be 1 for MLA
    head_size: int,
) -> Tuple[int, ...]:
    return (num_blocks, block_size, head_size)
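
Unlike a standard paged-attention cache, the MLA cache stores a single compressed latent per token, so the returned shape has no num_kv_heads dimension (it is assumed to be 1). An illustrative calculation with made-up sizes:

num_blocks, block_size, head_size = 1024, 16, 576  # example values only
shape = CPUMLABackend.get_kv_cache_shape(num_blocks, block_size,
                                         num_kv_heads=1, head_size=head_size)
assert shape == (1024, 16, 576)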

get_metadata_cls staticmethod

get_metadata_cls() -> Type[CPUMLAMetadata]
Source code in vllm/attention/backends/cpu_mla.py
@staticmethod
def get_metadata_cls() -> Type["CPUMLAMetadata"]:
    return CPUMLAMetadata

get_name staticmethod

get_name() -> str
Source code in vllm/attention/backends/cpu_mla.py
@staticmethod
def get_name() -> str:
    return "CPU_MLA"

get_state_cls staticmethod

get_state_cls() -> Type[MLACommonState]
Source code in vllm/attention/backends/cpu_mla.py
@staticmethod
def get_state_cls() -> Type["MLACommonState"]:
    return MLACommonState

get_supported_head_sizes staticmethod

get_supported_head_sizes() -> List[int]
Source code in vllm/attention/backends/cpu_mla.py
@staticmethod
def get_supported_head_sizes() -> List[int]:
    return [576]
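
The single supported head size of 576 matches the width of the cached MLA latent plus its rotary part. For DeepSeek-style MLA models this is typically kv_lora_rank (512) plus qk_rope_head_dim (64); the exact split depends on the model config, so treat the numbers below as an assumption:

kv_lora_rank = 512       # assumed typical value
qk_rope_head_dim = 64    # assumed typical value
assert kv_lora_rank + qk_rope_head_dim == 576  # the only supported head size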

swap_blocks staticmethod

swap_blocks(
    src_kv_cache: Tensor,
    dst_kv_cache: Tensor,
    src_to_dst: Tensor,
) -> None
Source code in vllm/attention/backends/cpu_mla.py
@staticmethod
def swap_blocks(
    src_kv_cache: torch.Tensor,
    dst_kv_cache: torch.Tensor,
    src_to_dst: torch.Tensor,
) -> None:
    ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
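
src_to_dst describes which cache blocks move where. A hypothetical illustration; the (source, destination) pair layout shown here is an assumption for the sketch, not a documented contract:

import torch
# swap cache block 0 into slot 5 and block 3 into slot 7 (illustrative indices)
src_to_dst = torch.tensor([[0, 5], [3, 7]], dtype=torch.int64)
# CPUMLABackend.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)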

CPUMLAImpl

Bases: MLACommonImpl[CPUMLAMetadata]

Source code in vllm/attention/backends/cpu_mla.py
class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "CPUMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "CPUMLAImpl")

        # FP8 (quantized) KV cache is not yet supported.
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "CPUMLAImpl with FP8 KV cache not yet supported")

    def _forward_prefill(
            self,
            q: torch.Tensor,
            kv_c_normed: torch.Tensor,
            k_pe: torch.Tensor,
            kv_c_and_k_pe_cache: torch.Tensor,
            attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:

        prefill_metadata = attn_metadata.prefill_metadata
        assert prefill_metadata is not None

        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope\
            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

        # For MLA the v head dim is smaller than qk head dim so we pad out
        # v with 0s to match the qk head dim
        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                           value=0)

        output = torch.empty_like(q)
        ipex_ops.varlen_attention(
            query=q,
            key=k,
            value=v_padded,
            out=output,
            seqlen_q=prefill_metadata.prefill_query_start_loc,
            seqlen_k=prefill_metadata.prefill_query_start_loc,
            max_seqlen_q=prefill_metadata.max_query_len,
            max_seqlen_k=prefill_metadata.max_query_len,
            pdropout=0.0,
            softmax_scale=self.scale,
            zero_tensors=False,
            is_causal=True,
            return_softmax=False,
            gen_=None,
            logits_soft_cap=0.0,
            window_size_left=-1,
            window_size_right=-1,
            alibi_slopes=None,
        )

        # remove padding
        output = output.view(-1, self.num_heads,
                             q.shape[-1])[..., :v.shape[-1]]
        return output.reshape(-1, self.num_heads * v.shape[-1])

    def _forward_decode(
            self,
            q_nope: torch.Tensor,
            q_pe: torch.Tensor,
            kv_c_and_k_pe_cache: torch.Tensor,
            attn_metadata: CPUMLAMetadata,  # type: ignore[override]
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None

        q = torch.cat([q_nope, q_pe], dim=-1)
        o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank)

        # Run MQA
        ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
                                   decode_meta.block_tables,
                                   decode_meta.seq_lens_tensor)
        return self._v_up_proj(o)

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[List[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[Dict[str, Any]],
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    **mla_args,
) -> None
Source code in vllm/attention/backends/cpu_mla.py
def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]],
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        # MLA Specific Arguments
        **mla_args) -> None:
    super().__init__(num_heads, head_size, scale, num_kv_heads,
                     alibi_slopes, sliding_window, kv_cache_dtype,
                     blocksparse_params, logits_soft_cap, attn_type,
                     kv_sharing_target_layer_name, **mla_args)

    unsupported_features = [
        alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
    ]
    if any(unsupported_features):
        raise NotImplementedError(
            "CPUMLAImpl does not support one of the following: "
            "alibi_slopes, sliding_window, blocksparse_params, "
            "logits_soft_cap")

    if attn_type != AttentionType.DECODER:
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "CPUMLAImpl")

    # FP8 (quantized) KV cache is not yet supported.
    if is_quantized_kv_cache(self.kv_cache_dtype):
        raise NotImplementedError(
            "CPUMLAImpl with FP8 KV cache not yet supported")

_forward_decode

_forward_decode(
    q_nope: Tensor,
    q_pe: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    attn_metadata: CPUMLAMetadata,
) -> Tensor
Source code in vllm/attention/backends/cpu_mla.py
def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: CPUMLAMetadata,  # type: ignore[override]
) -> torch.Tensor:
    assert kv_c_and_k_pe_cache.numel() > 0

    decode_meta = attn_metadata.decode_metadata
    assert decode_meta is not None

    q = torch.cat([q_nope, q_pe], dim=-1)
    o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank)

    # Run MQA
    ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
                               decode_meta.block_tables,
                               decode_meta.seq_lens_tensor)
    return self._v_up_proj(o)
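
Decode runs as a single-KV-head (MQA-style) lookup against the compressed cache: q_nope and q_pe are concatenated into one 576-wide query per head, the kernel accumulates a kv_lora_rank-wide output, and _v_up_proj maps it back to value space. A shape-only sketch with illustrative sizes, not an executable call into the kernel:

import torch
num_tokens, num_heads = 4, 16           # illustrative
kv_lora_rank, rope_dim = 512, 64        # assumed typical MLA split
q_nope = torch.randn(num_tokens, num_heads, kv_lora_rank)
q_pe = torch.randn(num_tokens, num_heads, rope_dim)
q = torch.cat([q_nope, q_pe], dim=-1)   # (4, 16, 576), matches the cached latent width
o = q.new_empty(q.shape[0], num_heads, kv_lora_rank)  # (4, 16, 512) kernel output buffer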

_forward_prefill

_forward_prefill(
    q: Tensor,
    kv_c_normed: Tensor,
    k_pe: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    attn_metadata: CPUMLAMetadata,
) -> Tensor
Source code in vllm/attention/backends/cpu_mla.py
def _forward_prefill(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: CPUMLAMetadata,  # type: ignore[override]
) -> torch.Tensor:

    prefill_metadata = attn_metadata.prefill_metadata
    assert prefill_metadata is not None

    kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
        -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
    k_nope, v = kv_nope\
        .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

    k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

    # For MLA the v head dim is smaller than qk head dim so we pad out
    # v with 0s to match the qk head dim
    v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                       value=0)

    output = torch.empty_like(q)
    ipex_ops.varlen_attention(
        query=q,
        key=k,
        value=v_padded,
        out=output,
        seqlen_q=prefill_metadata.prefill_query_start_loc,
        seqlen_k=prefill_metadata.prefill_query_start_loc,
        max_seqlen_q=prefill_metadata.max_query_len,
        max_seqlen_k=prefill_metadata.max_query_len,
        pdropout=0.0,
        softmax_scale=self.scale,
        zero_tensors=False,
        is_causal=True,
        return_softmax=False,
        gen_=None,
        logits_soft_cap=0.0,
        window_size_left=-1,
        window_size_right=-1,
        alibi_slopes=None,
    )

    # remove padding
    output = output.view(-1, self.num_heads,
                         q.shape[-1])[..., :v.shape[-1]]
    return output.reshape(-1, self.num_heads * v.shape[-1])
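
The padding is needed because the varlen attention kernel expects query, key, and value to share a head dimension, while MLA's value head dimension (v_head_dim) is smaller than the QK head dimension. A standalone sketch of the pad-then-slice pattern used above, with illustrative sizes:

import torch
num_tokens, num_heads = 8, 16
qk_head_dim, v_head_dim = 192, 128      # illustrative MLA-style dims
v = torch.randn(num_tokens, num_heads, v_head_dim)
v_padded = torch.nn.functional.pad(v, [0, qk_head_dim - v_head_dim], value=0)
assert v_padded.shape[-1] == qk_head_dim
# after attention, the padded tail is sliced off again:
out = v_padded[..., :v_head_dim]
assert torch.equal(out, v)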

CPUMLAMetadata dataclass

Bases: TorchSDPAMetadata

Source code in vllm/attention/backends/cpu_mla.py
@dataclass
class CPUMLAMetadata(TorchSDPAMetadata):
    # New for MLA
    # Input positions for rotary embeddings, since for MLA the rotary
    # position embeddings are applied inside the attention backend
    input_positions: torch.Tensor = None

    # required by MLACommonImpl
    is_profile_run: bool = False

input_positions class-attribute instance-attribute

input_positions: Tensor = None

is_profile_run class-attribute instance-attribute

is_profile_run: bool = False

__init__

__init__(
    seq_lens_tensor: Optional[Tensor],
    max_decode_seq_len: int,
    block_tables: Optional[Tensor],
    num_prefills: int,
    num_prefill_tokens: int,
    num_decode_tokens: int,
    slot_mapping: Tensor,
    multi_modal_placeholder_index_maps: Optional[
        Dict[str, IndexMap]
    ],
    enable_kv_scales_calculation: bool,
    chunked_prefill: bool,
    seq_lens: Optional[List[int]] = None,
    max_query_len: Optional[int] = None,
    max_kv_len: Optional[int] = None,
    prefill_query_start_loc: Optional[Tensor] = None,
    kv_start_loc: Optional[Tensor] = None,
    prefill_block_tables: Optional[Tensor] = None,
    query_start_loc: Optional[Tensor] = None,
    encoder_seq_lens: Optional[List[int]] = None,
    encoder_seq_lens_tensor: Optional[Tensor] = None,
    max_encoder_seq_len: Optional[int] = None,
    num_encoder_tokens: Optional[int] = None,
    cross_slot_mapping: Optional[Tensor] = None,
    cross_block_tables: Optional[Tensor] = None,
    input_positions: Tensor = None,
    is_profile_run: bool = False,
) -> None

CPUMLAMetadataBuilder

Bases: AttentionMetadataBuilder[CPUMLAMetadata]

Source code in vllm/attention/backends/cpu_mla.py
class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        self.chunked_prefill = input_builder.chunked_prefill
        self.input_builder = input_builder
        assert not self.chunked_prefill, \
            "chunked prefill is currently not supported"

    def prepare(self):
        self.input_data = self.input_builder.input_data

    def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
        input_data = self.input_data
        prefill_seq_lens = seq_lens[0:input_data.num_prefills]
        prefill_query_lens = query_lens[0:input_data.num_prefills]
        slot_mapping = torch.tensor(input_data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # metadata for prefill
        if input_data.num_prefills > 0:
            query_lens_tensor = torch.tensor(prefill_query_lens,
                                             dtype=torch.int32,
                                             device="cpu")
            kv_lens_tensor = torch.tensor(prefill_seq_lens,
                                          dtype=torch.int32,
                                          device="cpu")
            query_start_loc = torch.zeros(input_data.num_prefills + 1,
                                          dtype=torch.int32,
                                          device="cpu")
            kv_start_loc = torch.zeros(input_data.num_prefills + 1,
                                       dtype=torch.int32,
                                       device="cpu")
            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=query_start_loc[1:])
            torch.cumsum(kv_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=kv_start_loc[1:])
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)

            # for chunked-prefill
            if self.chunked_prefill:
                prefill_block_tables = make_tensor_with_pad(
                    self.input_data.prefill_block_tables,
                    pad=0,
                    dtype=torch.int32,
                    device="cpu",
                )
            else:
                prefill_block_tables = None

        else:
            query_start_loc = None
            kv_start_loc = None
            max_query_len = None
            max_kv_len = None
            prefill_block_tables = None

        # metadata for decode
        if input_data.num_decode_tokens != 0:
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[input_data.num_prefills:],
                dtype=torch.int32,
                device="cpu",
            )
            block_tables = make_tensor_with_pad(
                self.input_data.decode_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
        else:
            block_tables = torch.tensor([])
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[:input_data.num_prefills],
                dtype=torch.int32,
                device="cpu",
            )

        # For multi-modal models
        placeholder_index_maps = None
        if len(input_data.multi_modal_inputs_list) != 0:
            placeholder_index_maps = {
                modality: placeholder_map.index_map()
                for modality, placeholder_map in
                input_data.multi_modal_placeholder_maps.items()
            }

        return CPUMLAMetadata(
            chunked_prefill=self.chunked_prefill,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            prefill_query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            max_decode_seq_len=input_data.max_decode_seq_len,
            num_prefills=input_data.num_prefills,
            num_prefill_tokens=input_data.num_prefill_tokens,
            num_decode_tokens=input_data.num_decode_tokens,
            block_tables=block_tables,
            prefill_block_tables=prefill_block_tables,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            input_positions=torch.tensor([self.input_data.input_positions]))

chunked_prefill instance-attribute

chunked_prefill = chunked_prefill

input_builder instance-attribute

input_builder = input_builder

__init__

__init__(input_builder: ModelInputForCPUBuilder) -> None
Source code in vllm/attention/backends/cpu_mla.py
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
    self.chunked_prefill = input_builder.chunked_prefill
    self.input_builder = input_builder
    assert not self.chunked_prefill, \
        "chunked prefill is currently not supported"

build

build(
    seq_lens, query_lens, cuda_graph_pad_size, batch_size
)
Source code in vllm/attention/backends/cpu_mla.py
def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
    input_data = self.input_data
    prefill_seq_lens = seq_lens[0:input_data.num_prefills]
    prefill_query_lens = query_lens[0:input_data.num_prefills]
    slot_mapping = torch.tensor(input_data.slot_mapping,
                                dtype=torch.long,
                                device="cpu")

    # metadata for prefill
    if input_data.num_prefills > 0:
        query_lens_tensor = torch.tensor(prefill_query_lens,
                                         dtype=torch.int32,
                                         device="cpu")
        kv_lens_tensor = torch.tensor(prefill_seq_lens,
                                      dtype=torch.int32,
                                      device="cpu")
        query_start_loc = torch.zeros(input_data.num_prefills + 1,
                                      dtype=torch.int32,
                                      device="cpu")
        kv_start_loc = torch.zeros(input_data.num_prefills + 1,
                                   dtype=torch.int32,
                                   device="cpu")
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=torch.int32,
                     out=query_start_loc[1:])
        torch.cumsum(kv_lens_tensor,
                     dim=0,
                     dtype=torch.int32,
                     out=kv_start_loc[1:])
        max_query_len = max(prefill_query_lens)
        max_kv_len = max(prefill_seq_lens)

        # for chunked-prefill
        if self.chunked_prefill:
            prefill_block_tables = make_tensor_with_pad(
                self.input_data.prefill_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
        else:
            prefill_block_tables = None

    else:
        query_start_loc = None
        kv_start_loc = None
        max_query_len = None
        max_kv_len = None
        prefill_block_tables = None

    # metadata for decode
    if input_data.num_decode_tokens != 0:
        seq_lens_tensor = torch.tensor(
            input_data.seq_lens[input_data.num_prefills:],
            dtype=torch.int32,
            device="cpu",
        )
        block_tables = make_tensor_with_pad(
            self.input_data.decode_block_tables,
            pad=0,
            dtype=torch.int32,
            device="cpu",
        )
    else:
        block_tables = torch.tensor([])
        seq_lens_tensor = torch.tensor(
            input_data.seq_lens[:input_data.num_prefills],
            dtype=torch.int32,
            device="cpu",
        )

    # For multi-modal models
    placeholder_index_maps = None
    if len(input_data.multi_modal_inputs_list) != 0:
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            input_data.multi_modal_placeholder_maps.items()
        }

    return CPUMLAMetadata(
        chunked_prefill=self.chunked_prefill,
        seq_lens=prefill_seq_lens,
        seq_lens_tensor=seq_lens_tensor,
        max_query_len=max_query_len,
        max_kv_len=max_kv_len,
        prefill_query_start_loc=query_start_loc,
        kv_start_loc=kv_start_loc,
        max_decode_seq_len=input_data.max_decode_seq_len,
        num_prefills=input_data.num_prefills,
        num_prefill_tokens=input_data.num_prefill_tokens,
        num_decode_tokens=input_data.num_decode_tokens,
        block_tables=block_tables,
        prefill_block_tables=prefill_block_tables,
        slot_mapping=slot_mapping,
        multi_modal_placeholder_index_maps=placeholder_index_maps,
        enable_kv_scales_calculation=False,
        input_positions=torch.tensor([self.input_data.input_positions]))
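
query_start_loc and kv_start_loc are prefix-sum offsets: entry i gives where sequence i starts in the flattened token stream, and the final entry is the total token count. A small worked example mirroring the cumsum pattern above, with illustrative lengths:

import torch
prefill_query_lens = [3, 5, 2]                       # illustrative lengths
query_lens_tensor = torch.tensor(prefill_query_lens, dtype=torch.int32)
query_start_loc = torch.zeros(len(prefill_query_lens) + 1, dtype=torch.int32)
torch.cumsum(query_lens_tensor, dim=0, dtype=torch.int32,
             out=query_start_loc[1:])
# query_start_loc is now tensor([0, 3, 8, 10]); sequence i spans
# [query_start_loc[i], query_start_loc[i+1]) in the packed query tensor.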

prepare

prepare()
Source code in vllm/attention/backends/cpu_mla.py
def prepare(self):
    self.input_data = self.input_builder.input_data