vllm.v1.attention.backends.cpu_attn

TorchSDPABackend

Source code in vllm/v1/attention/backends/cpu_attn.py
class TorchSDPABackend:
    accept_output_buffer: bool = False

    @staticmethod
    def get_name() -> str:
        return "TORCH_SDPA_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TorchSDPABackendImpl"]:
        return TorchSDPABackendImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return TorchSDPAMetadata

    @staticmethod
    def get_state_cls() -> type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
        return TorchSDPAMetadataBuilderV1

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False
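
A minimal sketch of querying the backend's static interface directly, assuming a vLLM build with CPU support is installed; the expected values follow from the source above:

from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend

print(TorchSDPABackend.get_name())                   # "TORCH_SDPA_VLLM_V1"
print(TorchSDPABackend.get_impl_cls().__name__)      # TorchSDPABackendImpl
print(TorchSDPABackend.get_builder_cls().__name__)   # TorchSDPAMetadataBuilderV1
print(TorchSDPABackend.use_cascade_attention())      # False; cascade attention is not used here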

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = False

get_builder_cls staticmethod

get_builder_cls() -> type[TorchSDPAMetadataBuilderV1]
Source code in vllm/v1/attention/backends/cpu_attn.py
@staticmethod
def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
    return TorchSDPAMetadataBuilderV1

get_impl_cls staticmethod

get_impl_cls() -> type[TorchSDPABackendImpl]
Source code in vllm/v1/attention/backends/cpu_attn.py
@staticmethod
def get_impl_cls() -> type["TorchSDPABackendImpl"]:
    return TorchSDPABackendImpl

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/cpu_attn.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> tuple[int, ...]:
    return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                             num_kv_heads, head_size)
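
The KV cache shape is delegated to PagedAttention, so the concrete layout comes from the CPU paged-attention helper rather than from this backend itself. A hedged sketch that only inspects the result, with illustrative (not default) sizes:

from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend

shape = TorchSDPABackend.get_kv_cache_shape(
    num_blocks=128,   # illustrative values, not defaults
    block_size=16,
    num_kv_heads=8,
    head_size=64,
)
print(shape)  # tuple of ints describing the paged KV cache layout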

get_metadata_cls staticmethod

get_metadata_cls() -> type[AttentionMetadata]
Source code in vllm/v1/attention/backends/cpu_attn.py
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
    return TorchSDPAMetadata

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/cpu_attn.py
@staticmethod
def get_name() -> str:
    return "TORCH_SDPA_VLLM_V1"

get_state_cls staticmethod

get_state_cls() -> type[CommonAttentionState]
Source code in vllm/v1/attention/backends/cpu_attn.py
@staticmethod
def get_state_cls() -> type["CommonAttentionState"]:
    return CommonAttentionState

use_cascade_attention staticmethod

use_cascade_attention(*args, **kwargs) -> bool
Source code in vllm/v1/attention/backends/cpu_attn.py
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
    return False

TorchSDPAMetadataBuilderV1

Bases: AttentionMetadataBuilder[TorchSDPAMetadata]

Source code in vllm/v1/attention/backends/cpu_attn.py
class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):

    def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable) -> None:
        self.runner = runner
        self.block_table = block_table

        # For reorder
        self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
                                                      dtype=np.int64)
        self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
                                                      dtype=np.int64)
        self.num_prompt_req: int = 0

        self.seq_start_loc_cpu = torch.zeros(
            runner.max_num_reqs + 1,
            dtype=torch.int32,
            device="cpu",
        )
        self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()

    def reorder_batch(self, input_batch: InputBatch,
                      scheduler_output: SchedulerOutput) -> bool:
        prompt_list_idx = 0
        decode_list_idx = 0
        for req_index in range(input_batch.num_reqs):
            if input_batch.num_computed_tokens_cpu[
                    req_index] < input_batch.num_prompt_tokens[req_index]:
                # prompt stage
                self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
                prompt_list_idx += 1
            else:
                # decode stage
                self.reorder_decode_req_index_list[decode_list_idx] = req_index
                decode_list_idx += 1
        assert decode_list_idx + prompt_list_idx == input_batch.num_reqs

        # Update prompt requests number
        self.num_prompt_req = prompt_list_idx

        reorder_req_num = 0
        for req_index in range(decode_list_idx):
            if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
                reorder_req_num += 1
            else:
                break

        if reorder_req_num == 0:
            return False

        reorder_prompt_list = (
            self.reorder_prompt_req_index_list[:prompt_list_idx]
            [-reorder_req_num:])
        reorder_decode_list = (
            self.reorder_decode_req_index_list[:decode_list_idx]
            [:reorder_req_num])
        assert reorder_decode_list.size == reorder_prompt_list.size

        for idx in range(reorder_req_num):
            prompt_req_index = reorder_prompt_list[idx].item()
            decode_req_index = reorder_decode_list[idx].item()
            input_batch.swap_states(prompt_req_index, decode_req_index)

        return True

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        runner = self.runner
        block_table = self.block_table
        seq_lens_np = runner.seq_lens_np[:num_reqs]
        num_prompt_req = self.num_prompt_req
        max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
        ) if num_prompt_req > 0 else 0
        max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
        ) if num_prompt_req < num_reqs else 0
        self.seq_start_loc_np[0] = 0
        np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
        num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
        num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
        ) - num_prefill_tokens
        slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
        block_table_tensor = block_table.get_device_tensor()
        attn_metadata = TorchSDPAMetadata(
            num_prefills=num_prompt_req,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            slot_mapping=slot_mapping,
            seq_lens_tensor=runner.
            seq_lens_cpu[num_prompt_req:num_reqs],  # decode
            max_decode_seq_len=max_decode_seq_len,  # decode
            block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
            chunked_prefill=True,
            max_query_len=max_query_len,
            max_kv_len=max_prefill_seq_len,
            prefill_query_start_loc=runner.
            query_start_loc_cpu[:num_prompt_req + 1],  # prefill
            kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
                                                1],  # prefill
            prefill_block_tables=block_table_tensor[:
                                                    num_prompt_req],  # prefill
            query_start_loc=runner.query_start_loc_cpu[:num_reqs +
                                                       1],  # for logits index
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
        )

        return attn_metadata
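
The builder is normally constructed inside vLLM's CPU model runner, which supplies the runner itself, the KV cache spec, and the block table; user code typically only resolves the class through the backend. A sketch, with the constructor arguments left symbolic because they are internal runner objects:

from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend

builder_cls = TorchSDPABackend.get_builder_cls()   # TorchSDPAMetadataBuilderV1
# builder = builder_cls(runner, kv_cache_spec, block_table)  # done by the CPU model runner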

block_table instance-attribute

block_table = block_table

num_prompt_req instance-attribute

num_prompt_req: int = 0

reorder_decode_req_index_list instance-attribute

reorder_decode_req_index_list = empty(
    max_num_reqs, dtype=int64
)

reorder_prompt_req_index_list instance-attribute

reorder_prompt_req_index_list = empty(
    max_num_reqs, dtype=int64
)

runner instance-attribute

runner = runner

seq_start_loc_cpu instance-attribute

seq_start_loc_cpu = zeros(
    max_num_reqs + 1, dtype=int32, device="cpu"
)

seq_start_loc_np instance-attribute

seq_start_loc_np = seq_start_loc_cpu.numpy()

__init__

__init__(
    runner: CPUModelRunner,
    kv_cache_spec: AttentionSpec,
    block_table: BlockTable,
) -> None
Source code in vllm/v1/attention/backends/cpu_attn.py
def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
             block_table: BlockTable) -> None:
    self.runner = runner
    self.block_table = block_table

    # For reorder
    self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
                                                  dtype=np.int64)
    self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
                                                  dtype=np.int64)
    self.num_prompt_req: int = 0

    self.seq_start_loc_cpu = torch.zeros(
        runner.max_num_reqs + 1,
        dtype=torch.int32,
        device="cpu",
    )
    self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
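
Note that seq_start_loc_np is a NumPy view of seq_start_loc_cpu: torch.Tensor.numpy() shares storage with the CPU tensor, so the np.cumsum(..., out=...) write in build() updates the torch tensor in place. A standalone sketch of that behavior:

import numpy as np
import torch

t = torch.zeros(5, dtype=torch.int32)
a = t.numpy()  # view over the same storage, no copy
np.cumsum(np.array([1, 2, 3, 4], dtype=np.int32), out=a[1:])
print(t)       # tensor([ 0,  1,  3,  6, 10], dtype=torch.int32)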

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
)
Source code in vllm/v1/attention/backends/cpu_attn.py
def build(self, common_prefix_len: int,
          common_attn_metadata: CommonAttentionMetadata):
    num_reqs = common_attn_metadata.num_reqs
    num_actual_tokens = common_attn_metadata.num_actual_tokens
    max_query_len = common_attn_metadata.max_query_len

    runner = self.runner
    block_table = self.block_table
    seq_lens_np = runner.seq_lens_np[:num_reqs]
    num_prompt_req = self.num_prompt_req
    max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
    ) if num_prompt_req > 0 else 0
    max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
    ) if num_prompt_req < num_reqs else 0
    self.seq_start_loc_np[0] = 0
    np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
    num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
    num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
    ) - num_prefill_tokens
    slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
    block_table_tensor = block_table.get_device_tensor()
    attn_metadata = TorchSDPAMetadata(
        num_prefills=num_prompt_req,
        num_prefill_tokens=num_prefill_tokens,
        num_decode_tokens=num_decode_tokens,
        slot_mapping=slot_mapping,
        seq_lens_tensor=runner.
        seq_lens_cpu[num_prompt_req:num_reqs],  # decode
        max_decode_seq_len=max_decode_seq_len,  # decode
        block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
        chunked_prefill=True,
        max_query_len=max_query_len,
        max_kv_len=max_prefill_seq_len,
        prefill_query_start_loc=runner.
        query_start_loc_cpu[:num_prompt_req + 1],  # prefill
        kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
                                            1],  # prefill
        prefill_block_tables=block_table_tensor[:
                                                num_prompt_req],  # prefill
        query_start_loc=runner.query_start_loc_cpu[:num_reqs +
                                                   1],  # for logits index
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=False,
    )

    return attn_metadata
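
build() assumes the batch is already ordered prompts-first (the invariant maintained by reorder_batch), which is what lets it split every per-request array at num_prompt_req. A standalone NumPy sketch of the index arithmetic, with made-up lengths:

import numpy as np

# 4 requests, prompts first: two prefills (query lengths 10 and 7) and two
# single-token decodes.
num_reqs, num_prompt_req = 4, 2
seq_lens = np.array([10, 7, 32, 40], dtype=np.int32)             # context lengths
query_start_loc = np.array([0, 10, 17, 18, 19], dtype=np.int32)  # prefix sums of query lengths

seq_start_loc = np.zeros(num_reqs + 1, dtype=np.int32)
np.cumsum(seq_lens, out=seq_start_loc[1:])            # seq_start_loc -> [0, 10, 17, 49, 89]

num_prefill_tokens = int(query_start_loc[num_prompt_req])                # 17
num_decode_tokens = int(query_start_loc[num_reqs]) - num_prefill_tokens  # 2
max_prefill_seq_len = int(seq_lens[:num_prompt_req].max())               # 10
max_decode_seq_len = int(seq_lens[num_prompt_req:].max())                # 40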

reorder_batch

reorder_batch(
    input_batch: InputBatch,
    scheduler_output: SchedulerOutput,
) -> bool
Source code in vllm/v1/attention/backends/cpu_attn.py
def reorder_batch(self, input_batch: InputBatch,
                  scheduler_output: SchedulerOutput) -> bool:
    prompt_list_idx = 0
    decode_list_idx = 0
    for req_index in range(input_batch.num_reqs):
        if input_batch.num_computed_tokens_cpu[
                req_index] < input_batch.num_prompt_tokens[req_index]:
            # prompt stage
            self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
            prompt_list_idx += 1
        else:
            # decode stage
            self.reorder_decode_req_index_list[decode_list_idx] = req_index
            decode_list_idx += 1
    assert decode_list_idx + prompt_list_idx == input_batch.num_reqs

    # Update prompt requests number
    self.num_prompt_req = prompt_list_idx

    reorder_req_num = 0
    for req_index in range(decode_list_idx):
        if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
            reorder_req_num += 1
        else:
            break

    if reorder_req_num == 0:
        return False

    reorder_prompt_list = (
        self.reorder_prompt_req_index_list[:prompt_list_idx]
        [-reorder_req_num:])
    reorder_decode_list = (
        self.reorder_decode_req_index_list[:decode_list_idx]
        [:reorder_req_num])
    assert reorder_decode_list.size == reorder_prompt_list.size

    for idx in range(reorder_req_num):
        prompt_req_index = reorder_prompt_list[idx].item()
        decode_req_index = reorder_decode_list[idx].item()
        input_batch.swap_states(prompt_req_index, decode_req_index)

    return True
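
The reordering only swaps as many request pairs as necessary: decode requests occupying the leading "prompt region" of the batch are paired with prompt requests sitting past it, and each pair is exchanged via InputBatch.swap_states. A standalone sketch of that pairing logic, with a hypothetical swap() callback standing in for swap_states:

def reorder(num_reqs, is_prompt, swap):
    # Partition request indices, mirroring the two index lists above.
    prompt_idx = [i for i in range(num_reqs) if is_prompt(i)]
    decode_idx = [i for i in range(num_reqs) if not is_prompt(i)]

    # Decode requests sitting in the leading prompt region must move out.
    num_prompt = len(prompt_idx)
    misplaced = [i for i in decode_idx if i < num_prompt]
    if not misplaced:
        return False

    # Pair them with the prompt indices currently sitting past that region
    # (the tail of the sorted prompt list) and swap each pair.
    for p, d in zip(prompt_idx[-len(misplaced):], misplaced):
        swap(p, d)
    return True

# Example: slots 1 and 3 hold prompts, slots 0 and 2 hold decodes.
reorder(4, lambda i: i in (1, 3), lambda p, d: print("swap", p, d))
# swap 3 0  -> afterwards slots 0-1 are prompts, slots 2-3 are decodes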