vllm.attention.backends.torch_sdpa

Attention layer with torch scaled_dot_product_attention and PagedAttention.
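
A quick orientation note: vLLM typically lets you opt into a specific attention
backend via the VLLM_ATTENTION_BACKEND environment variable, whose value matches
TorchSDPABackend.get_name() below. The snippet is a hedged sketch; the exact
selection mechanics vary by vLLM version and platform (this backend targets CPU).

import os

# Assumption: backend selection via environment variable. Verify against your
# vLLM version; the value matches TorchSDPABackend.get_name() ("TORCH_SDPA").
os.environ["VLLM_ATTENTION_BACKEND"] = "TORCH_SDPA"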

logger module-attribute

logger = init_logger(__name__)

TorchSDPABackend

Bases: AttentionBackend

Source code in vllm/attention/backends/torch_sdpa.py
class TorchSDPABackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "TORCH_SDPA"

    @staticmethod
    def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
        return TorchSDPABackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return TorchSDPAMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
        return TorchSDPAMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        raise NotImplementedError("Swap is not supported in TorchSDPABackend.")

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)
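
For orientation, a minimal sketch (not vLLM's actual wiring) of how a caller
might consume the static interface above; every return value shown is taken
directly from the class:

backend = TorchSDPABackend
assert backend.get_name() == "TORCH_SDPA"
impl_cls = backend.get_impl_cls()          # TorchSDPABackendImpl
metadata_cls = backend.get_metadata_cls()  # TorchSDPAMetadata
builder_cls = backend.get_builder_cls()    # TorchSDPAMetadataBuilder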

copy_blocks staticmethod

copy_blocks(
    kv_caches: List[Tensor], src_to_dists: Tensor
) -> None
Source code in vllm/attention/backends/torch_sdpa.py
@staticmethod
def copy_blocks(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
) -> None:
    PagedAttention.copy_blocks(kv_caches, src_to_dists)

get_builder_cls staticmethod

get_builder_cls() -> Type[TorchSDPAMetadataBuilder]
Source code in vllm/attention/backends/torch_sdpa.py
@staticmethod
def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
    return TorchSDPAMetadataBuilder

get_impl_cls staticmethod

get_impl_cls() -> Type[TorchSDPABackendImpl]
Source code in vllm/attention/backends/torch_sdpa.py
@staticmethod
def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
    return TorchSDPABackendImpl

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> Tuple[int, ...]
Source code in vllm/attention/backends/torch_sdpa.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> Tuple[int, ...]:
    return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                             num_kv_heads, head_size)
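
As the forward() docstring below documents, this shape is
[2, num_blocks, block_size * num_kv_heads * head_size]; the first axis
separates the key and value caches (cf. PagedAttention.split_kv_cache in
forward()). A worked example with illustrative sizes:

# Illustrative sizes only.
num_blocks, block_size, num_kv_heads, head_size = 128, 16, 8, 64
shape = (2, num_blocks, block_size * num_kv_heads * head_size)
assert shape == (2, 128, 8192)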

get_metadata_cls staticmethod

get_metadata_cls() -> Type[AttentionMetadata]
Source code in vllm/attention/backends/torch_sdpa.py
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
    return TorchSDPAMetadata

get_name staticmethod

get_name() -> str
Source code in vllm/attention/backends/torch_sdpa.py
@staticmethod
def get_name() -> str:
    return "TORCH_SDPA"

get_state_cls staticmethod

get_state_cls() -> Type[CommonAttentionState]
Source code in vllm/attention/backends/torch_sdpa.py
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
    return CommonAttentionState

swap_blocks staticmethod

swap_blocks(
    src_kv_cache: Tensor,
    dst_kv_cache: Tensor,
    src_to_dst: Tensor,
) -> None
Source code in vllm/attention/backends/torch_sdpa.py
@staticmethod
def swap_blocks(
    src_kv_cache: torch.Tensor,
    dst_kv_cache: torch.Tensor,
    src_to_dst: torch.Tensor,
) -> None:
    raise NotImplementedError("Swap is not supported in TorchSDPABackend.")

TorchSDPABackendImpl

Bases: AttentionImpl[TorchSDPAMetadata]

Source code in vllm/attention/backends/torch_sdpa.py
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if blocksparse_params is not None:
            raise ValueError(
                "Torch SPDA does not support block-sparse attention.")
        if logits_soft_cap is not None:
            logger.warning_once("Torch SPDA does not support logits soft cap. "
                                "Outputs may be slightly off.")
        if use_irope:
            logger.warning_once(
                "Using irope in Torch SPDA is not supported yet, it will fall"
                " back to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.need_mask = (self.alibi_slopes is not None
                          or self.sliding_window is not None)

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex:
            raise NotImplementedError(
                "Torch SDPA backend FP8 KV cache requires "
                "intel_extension_for_pytorch support.")
        self.attn_type = attn_type

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TorchSDPAMetadata,  # type: ignore
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with torch SDPA and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for the profiling run.
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for TorchSDPABackendImpl")

        # During the warm-up run, attn_metadata is None.
        if attn_metadata is None:
            return query

        attn_type = self.attn_type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")
        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
            # The KV cache is used during decoder self-attention and
            # encoder/decoder cross-attention, but not during encoder
            # attention.
            #
            # Even if there are no new key/value pairs to cache, we
            # still need to split out key_cache and value_cache for
            # later use by paged attention.
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):
                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only)
                    # During cross-attention decode, key & value will be None,
                    # preventing this IF-statement branch from running
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode)
                    updated_slot_mapping = attn_metadata.slot_mapping

                PagedAttention.write_to_paged_cache(
                    key, value, key_cache, value_cache, updated_slot_mapping,
                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)

        if attn_type != AttentionType.ENCODER:
            # Decoder self-attention supports chunked prefill.
            # Encoder/decoder cross-attention requires no chunked
            # prefill (100% prefill or 100% decode tokens, no mix)
            num_prefill_tokens = attn_metadata.num_prefill_tokens
            num_decode_tokens = attn_metadata.num_decode_tokens
        else:
            # Encoder attention - chunked prefill is not applicable;
            # derive the token count from the query shape and treat
            # all tokens as prefill tokens.
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens
            num_decode_tokens = 0

        if attn_type == AttentionType.DECODER:
            # Only enforce this shape-constraint for decoder
            # self-attention
            assert key.shape[0] == num_prefill_tokens + num_decode_tokens
            assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        if prefill_meta := attn_metadata.prefill_metadata:
            if not prefill_meta.prefill_metadata.chunked_prefill:  # type: ignore
                assert attn_metadata.seq_lens is not None
                self._run_sdpa_forward(output,
                                       query,
                                       key,
                                       value,
                                       prefill_meta,
                                       attn_type=attn_type)
            else:
                # prefix-enabled attention
                assert not self.need_mask
                import intel_extension_for_pytorch.llm.modules as ipex_modules
                output = torch.empty_like(query)
                ipex_modules.PagedAttention.flash_attn_varlen_func(
                    output[:prefill_meta.num_prefill_tokens, :, :],
                    query[:prefill_meta.num_prefill_tokens, :, :],
                    key_cache,
                    value_cache,
                    prefill_meta.prefill_query_start_loc,
                    prefill_meta.kv_start_loc,
                    prefill_meta.max_query_len,
                    prefill_meta.max_kv_len,
                    self.scale,
                    True,
                    prefill_meta.prefill_block_tables,
                    self.alibi_slopes,
                )

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")
            # Decoding run.
            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = decode_meta.get_seq_len_block_table_args(attn_type)

            PagedAttention.forward_decode(
                output[attn_metadata.num_prefill_tokens:, :, :],
                query[attn_metadata.num_prefill_tokens:, :, :],
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_sdpa_forward(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: TorchSDPAMetadata,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        if self.num_kv_heads != self.num_heads:
            key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
            value = value.repeat_interleave(self.num_queries_per_kv, dim=1)

        attn_masks = attn_metadata.get_attn_bias(attn_type)
        if attn_masks is None:
            if self.alibi_slopes is not None:
                attn_masks = _make_alibi_bias(
                    self.alibi_slopes, query.dtype,
                    attn_metadata.seq_lens)  # type: ignore
            elif self.sliding_window is not None:
                assert attn_metadata.seq_lens is not None
                attn_masks = _make_sliding_window_bias(
                    attn_metadata.seq_lens, self.sliding_window,
                    query.dtype)  # type: ignore
            else:
                seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
                attn_masks = [None] * len(seq_lens)
            attn_metadata.set_attn_bias(attn_masks, attn_type)

        query = query.movedim(0, query.dim() - 2)
        key = key.movedim(0, key.dim() - 2)
        value = value.movedim(0, value.dim() - 2)

        causal_attn = (attn_type == AttentionType.DECODER)

        seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
        start_q, start_kv = 0, 0
        for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
                                               attn_masks):
            end_q = start_q + seq_len_q
            end_kv = start_kv + seq_len_kv
            sub_out = scaled_dot_product_attention(
                query[None, :, start_q:end_q, :],
                key[None, :, start_kv:end_kv, :],
                value[None, :, start_kv:end_kv, :],
                attn_mask=mask,
                dropout_p=0.0,
                is_causal=causal_attn and mask is None,
                scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
            output[start_q:end_q, :, :] = sub_out
            start_q, start_kv = end_q, end_kv

alibi_slopes instance-attribute

alibi_slopes = alibi_slopes

attn_type instance-attribute

attn_type = attn_type

head_size instance-attribute

head_size = head_size

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

need_mask instance-attribute

need_mask = (
    alibi_slopes is not None or sliding_window is not None
)

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

scale instance-attribute

scale = float(scale)

sliding_window instance-attribute

sliding_window = sliding_window

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[List[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[Dict[str, Any]] = None,
    logits_soft_cap: Optional[float] = None,
    attn_type: str = DECODER,
    kv_sharing_target_layer_name: Optional[str] = None,
    use_irope: bool = False,
) -> None
Source code in vllm/attention/backends/torch_sdpa.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[List[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[Dict[str, Any]] = None,
    logits_soft_cap: Optional[float] = None,
    attn_type: str = AttentionType.DECODER,
    kv_sharing_target_layer_name: Optional[str] = None,
    use_irope: bool = False,
) -> None:
    if kv_sharing_target_layer_name is not None:
        raise NotImplementedError("KV sharing is not supported in V0.")
    if blocksparse_params is not None:
        raise ValueError(
            "Torch SPDA does not support block-sparse attention.")
    if logits_soft_cap is not None:
        logger.warning_once("Torch SPDA does not support logits soft cap. "
                            "Outputs may be slightly off.")
    if use_irope:
        logger.warning_once(
            "Using irope in Torch SPDA is not supported yet, it will fall"
            " back to global attention for long context.")
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    if alibi_slopes is not None:
        alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
    self.alibi_slopes = alibi_slopes
    self.sliding_window = sliding_window
    self.kv_cache_dtype = kv_cache_dtype

    self.num_queries_per_kv = self.num_heads // self.num_kv_heads
    self.need_mask = (self.alibi_slopes is not None
                      or self.sliding_window is not None)

    supported_head_sizes = PagedAttention.get_supported_head_sizes()
    if head_size not in supported_head_sizes:
        raise ValueError(
            f"Head size {head_size} is not supported by PagedAttention. "
            f"Supported head sizes are: {supported_head_sizes}.")

    if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex:
        raise NotImplementedError(
            "Torch SDPA backend FP8 KV cache requires "
            "intel_extension_for_pytorch support.")
    self.attn_type = attn_type
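
Two derived values above lend themselves to a quick worked example (values are
illustrative):

# 32 query heads sharing 8 KV heads -> each KV head serves 4 query heads.
num_heads, num_kv_heads = 32, 8
num_queries_per_kv = num_heads // num_kv_heads
assert num_queries_per_kv == 4

# A custom additive mask is only needed for ALiBi or sliding-window attention.
alibi_slopes, sliding_window = None, 1024
need_mask = alibi_slopes is not None or sliding_window is not None
assert need_mask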

_run_sdpa_forward

_run_sdpa_forward(
    output: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_metadata: TorchSDPAMetadata,
    attn_type: str = DECODER,
) -> None
Source code in vllm/attention/backends/torch_sdpa.py
def _run_sdpa_forward(
    self,
    output: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: TorchSDPAMetadata,
    attn_type: str = AttentionType.DECODER,
) -> None:
    if self.num_kv_heads != self.num_heads:
        key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
        value = value.repeat_interleave(self.num_queries_per_kv, dim=1)

    attn_masks = attn_metadata.get_attn_bias(attn_type)
    if attn_masks is None:
        if self.alibi_slopes is not None:
            attn_masks = _make_alibi_bias(
                self.alibi_slopes, query.dtype,
                attn_metadata.seq_lens)  # type: ignore
        elif self.sliding_window is not None:
            assert attn_metadata.seq_lens is not None
            attn_masks = _make_sliding_window_bias(
                attn_metadata.seq_lens, self.sliding_window,
                query.dtype)  # type: ignore
        else:
            seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
            attn_masks = [None] * len(seq_lens)
        attn_metadata.set_attn_bias(attn_masks, attn_type)

    query = query.movedim(0, query.dim() - 2)
    key = key.movedim(0, key.dim() - 2)
    value = value.movedim(0, value.dim() - 2)

    causal_attn = (attn_type == AttentionType.DECODER)

    seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
    start_q, start_kv = 0, 0
    for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv,
                                           attn_masks):
        end_q = start_q + seq_len_q
        end_kv = start_kv + seq_len_kv
        sub_out = scaled_dot_product_attention(
            query[None, :, start_q:end_q, :],
            key[None, :, start_kv:end_kv, :],
            value[None, :, start_kv:end_kv, :],
            attn_mask=mask,
            dropout_p=0.0,
            is_causal=causal_attn and mask is None,
            scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
        output[start_q:end_q, :, :] = sub_out
        start_q, start_kv = end_q, end_kv
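
A self-contained sketch of the grouped-query head expansion performed at the top
of _run_sdpa_forward, with illustrative shapes:

import torch

num_tokens, num_heads, num_kv_heads, head_size = 4, 8, 2, 16
key = torch.randn(num_tokens, num_kv_heads, head_size)
# Each KV head is repeated num_heads // num_kv_heads times along dim=1 so that
# scaled_dot_product_attention sees matching query and key/value head counts.
key = key.repeat_interleave(num_heads // num_kv_heads, dim=1)
assert key.shape == (num_tokens, num_heads, head_size)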

forward

forward(
    layer: AttentionLayer,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: TorchSDPAMetadata,
    output: Optional[Tensor] = None,
    output_scale: Optional[Tensor] = None,
) -> Tensor

Forward pass with torch SDPA and PagedAttention.

Parameters:

Name           Type               Description                                     Default
query          Tensor             shape = [num_tokens, num_heads * head_size]     required
key            Tensor             shape = [num_tokens, num_kv_heads * head_size]  required
value          Tensor             shape = [num_tokens, num_kv_heads * head_size]  required
attn_metadata  TorchSDPAMetadata  Metadata for attention.                         required

Returns: shape = [num_tokens, num_heads * head_size]

Source code in vllm/attention/backends/torch_sdpa.py
def forward(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: TorchSDPAMetadata,  # type: ignore
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward pass with torch SDPA and PagedAttention.

    Args:
        query: shape = [num_tokens, num_heads * head_size]
        key: shape = [num_tokens, num_kv_heads * head_size]
        value: shape = [num_tokens, num_kv_heads * head_size]
        kv_cache: shape = [2, num_blocks, block_size * num_kv_heads * head_size]
            NOTE: kv_cache will be an empty tensor with shape [0]
            for the profiling run.
        attn_metadata: Metadata for attention.
    Returns:
        shape = [num_tokens, num_heads * head_size]
    """
    if output_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported"
            " for TorchSDPABackendImpl")

    # During the warm-up run, attn_metadata is None.
    if attn_metadata is None:
        return query

    attn_type = self.attn_type
    if (attn_type == AttentionType.ENCODER
            and (not attn_metadata.is_all_encoder_attn_metadata_set)):
        raise AttributeError("Encoder attention requires setting "
                             "encoder metadata attributes.")
    elif (attn_type == AttentionType.ENCODER_DECODER
          and (not attn_metadata.is_all_cross_attn_metadata_set)):
        raise AttributeError("Encoder/decoder cross-attention "
                             "requires setting cross-attention "
                             "metadata attributes.")

    # Reshape the query, key, and value tensors.
    query = query.view(-1, self.num_heads, self.head_size)
    if key is not None:
        assert value is not None
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
    else:
        assert value is None

    if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
        # The KV cache is used during decoder self-attention and
        # encoder/decoder cross-attention, but not during encoder
        # attention.
        #
        # Even if there are no new key/value pairs to cache, we
        # still need to split out key_cache and value_cache for
        # later use by paged attention.
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, self.num_kv_heads, self.head_size)

        if (key is not None) and (value is not None):
            if attn_type == AttentionType.ENCODER_DECODER:
                # Update cross-attention KV cache (prefill-only)
                # During cross-attention decode, key & value will be None,
                # preventing this IF-statement branch from running
                updated_slot_mapping = attn_metadata.cross_slot_mapping
            else:
                # Update self-attention KV cache (prefill/decode)
                updated_slot_mapping = attn_metadata.slot_mapping

            PagedAttention.write_to_paged_cache(
                key, value, key_cache, value_cache, updated_slot_mapping,
                self.kv_cache_dtype, layer._k_scale, layer._v_scale)

    if attn_type != AttentionType.ENCODER:
        # Decoder self-attention supports chunked prefill.
        # Encoder/decoder cross-attention requires no chunked
        # prefill (100% prefill or 100% decode tokens, no mix)
        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
    else:
        # Encoder attention - chunked prefill is not applicable;
        # derive the token count from the query shape and treat
        # all tokens as prefill tokens.
        assert attn_metadata.num_encoder_tokens is not None
        num_prefill_tokens = attn_metadata.num_encoder_tokens
        num_decode_tokens = 0

    if attn_type == AttentionType.DECODER:
        # Only enforce this shape-constraint for decoder
        # self-attention
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

    output = torch.empty_like(query)
    if prefill_meta := attn_metadata.prefill_metadata:
        if not prefill_meta.prefill_metadata.chunked_prefill:  # type: ignore
            assert attn_metadata.seq_lens is not None
            self._run_sdpa_forward(output,
                                   query,
                                   key,
                                   value,
                                   prefill_meta,
                                   attn_type=attn_type)
        else:
            # prefix-enabled attention
            assert not self.need_mask
            import intel_extension_for_pytorch.llm.modules as ipex_modules
            output = torch.empty_like(query)
            ipex_modules.PagedAttention.flash_attn_varlen_func(
                output[:prefill_meta.num_prefill_tokens, :, :],
                query[:prefill_meta.num_prefill_tokens, :, :],
                key_cache,
                value_cache,
                prefill_meta.prefill_query_start_loc,
                prefill_meta.kv_start_loc,
                prefill_meta.max_query_len,
                prefill_meta.max_kv_len,
                self.scale,
                True,
                prefill_meta.prefill_block_tables,
                self.alibi_slopes,
            )

    if decode_meta := attn_metadata.decode_metadata:
        assert attn_type != AttentionType.ENCODER_ONLY, (
            "Encoder-only models should not have decode metadata.")
        # Decoding run.
        (
            seq_lens_arg,
            max_seq_len_arg,
            block_tables_arg,
        ) = decode_meta.get_seq_len_block_table_args(attn_type)

        PagedAttention.forward_decode(
            output[attn_metadata.num_prefill_tokens:, :, :],
            query[attn_metadata.num_prefill_tokens:, :, :],
            key_cache,
            value_cache,
            block_tables_arg,
            seq_lens_arg,
            max_seq_len_arg,
            self.kv_cache_dtype,
            self.num_kv_heads,
            self.scale,
            self.alibi_slopes,
            layer._k_scale,
            layer._v_scale,
        )

    # Reshape the output tensor.
    return output.view(-1, self.num_heads * self.head_size)

TorchSDPAMetadata dataclass

Bases: AttentionMetadata, PagedAttentionMetadata

Metadata for TorchSDPABackend.

Source code in vllm/attention/backends/torch_sdpa.py
@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for TorchSDPABackend.
    """
    # Whether chunked prefill is enabled for this run
    # (set from the model-input builder).
    chunked_prefill: bool
    seq_lens: Optional[List[int]] = None  # For non-chunked prefill

    # For chunked prefill only
    max_query_len: Optional[int] = None
    max_kv_len: Optional[int] = None
    prefill_query_start_loc: Optional[torch.Tensor] = None
    kv_start_loc: Optional[torch.Tensor] = None
    prefill_block_tables: Optional[torch.Tensor] = None

    # For V1 logits index only
    query_start_loc: Optional[torch.Tensor] = None

    # Begin encoder attn & enc/dec cross-attn fields...
    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because the bias must be set per prompt when
        # ALiBi slopes are used, due to a limitation of the xformers API.
        # These fields will not appear in __repr__ or __init__.
        self.attn_bias: Optional[List[torch.Tensor]] = None
        self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
        self.cross_attn_bias: Optional[List[torch.Tensor]] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        All attention metadata required for encoder attention is set.
        '''
        return ((self.encoder_seq_lens is not None)
                and (self.encoder_seq_lens_tensor is not None)
                and (self.max_encoder_seq_len is not None))

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        All attention metadata required for enc/dec cross-attention is set.

        Superset of encoder attention required metadata.
        '''
        return (self.is_all_encoder_attn_metadata_set
                and (self.cross_slot_mapping is not None)
                and (self.cross_block_tables is not None))

    @property
    def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
        if self.num_prefill_tokens == 0:
            return None
        return self

    @property
    def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
        if self.num_decode_tokens == 0:
            return None
        return self

    def get_seq_lens(
        self,
        attn_type: str,
    ):
        '''
        Extract appropriate sequence lengths from attention metadata
        according to attention type.

        Arguments:

        * attn_type: encoder attention, decoder self-attention, or
                     encoder/decoder cross-attention

        Returns:
        * Appropriate sequence lengths for the query
        * Appropriate sequence lengths for the key & value
        '''

        if (attn_type == AttentionType.DECODER
                or attn_type == AttentionType.ENCODER_ONLY):
            seq_lens_q = self.seq_lens
            seq_lens_kv = self.seq_lens
        elif attn_type == AttentionType.ENCODER:
            seq_lens_q = self.encoder_seq_lens
            seq_lens_kv = self.encoder_seq_lens
        elif attn_type == AttentionType.ENCODER_DECODER:
            seq_lens_q = self.seq_lens
            seq_lens_kv = self.encoder_seq_lens
        else:
            raise AttributeError(f"Invalid attention type {str(attn_type)}")
        return seq_lens_q, seq_lens_kv

    def get_attn_bias(
        self,
        attn_type: str,
    ) -> Optional[List[torch.Tensor]]:
        '''
        Extract appropriate attention bias from attention metadata
        according to attention type.

        Arguments:

        * attn_type: encoder attention, decoder self-attention, or
                     encoder/decoder cross-attention

        Returns:
        * Appropriate attention bias value given the attention type
        '''

        if (attn_type == AttentionType.DECODER
                or attn_type == AttentionType.ENCODER_ONLY):
            return self.attn_bias
        elif attn_type == AttentionType.ENCODER:
            return self.encoder_attn_bias
        elif attn_type == AttentionType.ENCODER_DECODER:
            return self.cross_attn_bias
        else:
            raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def set_attn_bias(
        self,
        attn_bias: List[torch.Tensor],
        attn_type: str,
    ) -> None:
        '''
        Update appropriate attention bias field of attention metadata,
        according to attention type.

        Arguments:

        * attn_bias: The desired attention bias value
        * attn_type: encoder attention, decoder self-attention, or
                     encoder/decoder cross-attention
        '''

        if (attn_type == AttentionType.DECODER
                or attn_type == AttentionType.ENCODER_ONLY):
            self.attn_bias = attn_bias
        elif attn_type == AttentionType.ENCODER:
            self.encoder_attn_bias = attn_bias
        elif attn_type == AttentionType.ENCODER_DECODER:
            self.cross_attn_bias = attn_bias
        else:
            raise AttributeError(f"Invalid attention type {str(attn_type)}")

    def get_seq_len_block_table_args(
        self,
        attn_type: str,
    ) -> tuple:
        '''
        The particular choice of sequence-length- and block-table-related
        attributes which should be extracted from attn_metadata is dependent
        on the type of attention operation.

        Decoder attn -> select entirely decoder self-attention-related fields
        Encoder/decoder cross-attn -> select encoder sequence lengths &
                                    cross-attn block-tables fields
        Encoder attn -> select encoder sequence lengths fields & no block tables

        Arguments:

        * attn_type: encoder attention, decoder self-attention, or
                     encoder/decoder cross-attention

        Returns:

        * Appropriate sequence-lengths tensor
        * Appropriate max sequence-length scalar
        * Appropriate block tables (or None)
        '''

        if (attn_type == AttentionType.DECODER
                or attn_type == AttentionType.ENCODER_ONLY):
            # Decoder self-attention
            # Choose max_seq_len based on whether we are in prompt_run
            return (self.seq_lens_tensor, self.max_decode_seq_len,
                    self.block_tables)
        elif attn_type == AttentionType.ENCODER_DECODER:
            # Enc/dec cross-attention KVs match encoder sequence length;
            # cross-attention utilizes special "cross" block tables
            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                    self.cross_block_tables)
        elif attn_type == AttentionType.ENCODER:
            # No block tables associated with encoder attention
            return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                    None)
        else:
            raise AttributeError(f"Invalid attention type {str(attn_type)}")

chunked_prefill instance-attribute

chunked_prefill: bool

cross_block_tables class-attribute instance-attribute

cross_block_tables: Optional[Tensor] = None

cross_slot_mapping class-attribute instance-attribute

cross_slot_mapping: Optional[Tensor] = None

decode_metadata property

decode_metadata: Optional[TorchSDPAMetadata]

encoder_seq_lens class-attribute instance-attribute

encoder_seq_lens: Optional[List[int]] = None

encoder_seq_lens_tensor class-attribute instance-attribute

encoder_seq_lens_tensor: Optional[Tensor] = None

is_all_cross_attn_metadata_set property

is_all_cross_attn_metadata_set

All attention metadata required for enc/dec cross-attention is set.

Superset of encoder attention required metadata.

is_all_encoder_attn_metadata_set property

is_all_encoder_attn_metadata_set

All attention metadata required for encoder attention is set.

kv_start_loc class-attribute instance-attribute

kv_start_loc: Optional[Tensor] = None

max_encoder_seq_len class-attribute instance-attribute

max_encoder_seq_len: Optional[int] = None

max_kv_len class-attribute instance-attribute

max_kv_len: Optional[int] = None

max_query_len class-attribute instance-attribute

max_query_len: Optional[int] = None

num_encoder_tokens class-attribute instance-attribute

num_encoder_tokens: Optional[int] = None

prefill_block_tables class-attribute instance-attribute

prefill_block_tables: Optional[Tensor] = None

prefill_metadata property

prefill_metadata: Optional[TorchSDPAMetadata]

prefill_query_start_loc class-attribute instance-attribute

prefill_query_start_loc: Optional[Tensor] = None

query_start_loc class-attribute instance-attribute

query_start_loc: Optional[Tensor] = None

seq_lens class-attribute instance-attribute

seq_lens: Optional[List[int]] = None

__init__

__init__(
    seq_lens_tensor: Optional[Tensor],
    max_decode_seq_len: int,
    block_tables: Optional[Tensor],
    num_prefills: int,
    num_prefill_tokens: int,
    num_decode_tokens: int,
    slot_mapping: Tensor,
    multi_modal_placeholder_index_maps: Optional[
        Dict[str, IndexMap]
    ],
    enable_kv_scales_calculation: bool,
    chunked_prefill: bool,
    seq_lens: Optional[List[int]] = None,
    max_query_len: Optional[int] = None,
    max_kv_len: Optional[int] = None,
    prefill_query_start_loc: Optional[Tensor] = None,
    kv_start_loc: Optional[Tensor] = None,
    prefill_block_tables: Optional[Tensor] = None,
    query_start_loc: Optional[Tensor] = None,
    encoder_seq_lens: Optional[List[int]] = None,
    encoder_seq_lens_tensor: Optional[Tensor] = None,
    max_encoder_seq_len: Optional[int] = None,
    num_encoder_tokens: Optional[int] = None,
    cross_slot_mapping: Optional[Tensor] = None,
    cross_block_tables: Optional[Tensor] = None,
) -> None

__post_init__

__post_init__()
Source code in vllm/attention/backends/torch_sdpa.py
def __post_init__(self):
    # Set during the execution of the first attention op.
    # It is a list because the bias must be set per prompt when
    # ALiBi slopes are used, due to a limitation of the xformers API.
    # These fields will not appear in __repr__ or __init__.
    self.attn_bias: Optional[List[torch.Tensor]] = None
    self.encoder_attn_bias: Optional[List[torch.Tensor]] = None
    self.cross_attn_bias: Optional[List[torch.Tensor]] = None

get_attn_bias

get_attn_bias(attn_type: str) -> Optional[List[Tensor]]

Extract appropriate attention bias from attention metadata according to attention type.

Arguments:

  • attn_type: encoder attention, decoder self-attention, or encoder/decoder cross-attention

Returns:

  • Appropriate attention bias value given the attention type

Source code in vllm/attention/backends/torch_sdpa.py
def get_attn_bias(
    self,
    attn_type: str,
) -> Optional[List[torch.Tensor]]:
    '''
    Extract appropriate attention bias from attention metadata
    according to attention type.

    Arguments:

    * attn_type: encoder attention, decoder self-attention, or
                 encoder/decoder cross-attention

    Returns:
    * Appropriate attention bias value given the attention type
    '''

    if (attn_type == AttentionType.DECODER
            or attn_type == AttentionType.ENCODER_ONLY):
        return self.attn_bias
    elif attn_type == AttentionType.ENCODER:
        return self.encoder_attn_bias
    elif attn_type == AttentionType.ENCODER_DECODER:
        return self.cross_attn_bias
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

get_seq_len_block_table_args

get_seq_len_block_table_args(attn_type: str) -> tuple

The particular choice of sequence-length- and block-table-related attributes which should be extracted from attn_metadata is dependent on the type of attention operation.

  • Decoder attn -> select entirely decoder self-attention-related fields
  • Encoder/decoder cross-attn -> select encoder sequence lengths & cross-attn block-tables fields
  • Encoder attn -> select encoder sequence lengths fields & no block tables

Arguments:

  • attn_type: encoder attention, decoder self-attention, or encoder/decoder cross-attention

Returns:

  • Appropriate sequence-lengths tensor
  • Appropriate max sequence-length scalar
  • Appropriate block tables (or None)
Source code in vllm/attention/backends/torch_sdpa.py
def get_seq_len_block_table_args(
    self,
    attn_type: str,
) -> tuple:
    '''
    The particular choice of sequence-length- and block-table-related
    attributes which should be extracted from attn_metadata is dependent
    on the type of attention operation.

    Decoder attn -> select entirely decoder self-attention-related fields
    Encoder/decoder cross-attn -> select encoder sequence lengths &
                                cross-attn block-tables fields
    Encoder attn -> select encoder sequence lengths fields & no block tables

    Arguments:

    * attn_type: encoder attention, decoder self-attention, or
                 encoder/decoder cross-attention

    Returns:

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)
    '''

    if (attn_type == AttentionType.DECODER
            or attn_type == AttentionType.ENCODER_ONLY):
        # Decoder self-attention
        # Choose max_seq_len based on whether we are in prompt_run
        return (self.seq_lens_tensor, self.max_decode_seq_len,
                self.block_tables)
    elif attn_type == AttentionType.ENCODER_DECODER:
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                self.cross_block_tables)
    elif attn_type == AttentionType.ENCODER:
        # No block tables associated with encoder attention
        return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len,
                None)
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

get_seq_lens

get_seq_lens(attn_type: str)

Extract appropriate sequence lengths from attention metadata according to attention type.

Arguments:

  • attn_type: encoder attention, decoder self-attention, or encoder/decoder cross-attention

Returns:

  • Appropriate sequence lengths for the query
  • Appropriate sequence lengths for the key & value

Source code in vllm/attention/backends/torch_sdpa.py
def get_seq_lens(
    self,
    attn_type: str,
):
    '''
    Extract appropriate sequence lengths from attention metadata
    according to attention type.

    Arguments:

    * attn_type: encoder attention, decoder self-attention, or
                 encoder/decoder cross-attention

    Returns:
    * Appropriate sequence lengths for the query
    * Appropriate sequence lengths for the key & value
    '''

    if (attn_type == AttentionType.DECODER
            or attn_type == AttentionType.ENCODER_ONLY):
        seq_lens_q = self.seq_lens
        seq_lens_kv = self.seq_lens
    elif attn_type == AttentionType.ENCODER:
        seq_lens_q = self.encoder_seq_lens
        seq_lens_kv = self.encoder_seq_lens
    elif attn_type == AttentionType.ENCODER_DECODER:
        seq_lens_q = self.seq_lens
        seq_lens_kv = self.encoder_seq_lens
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")
    return seq_lens_q, seq_lens_kv
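
A short illustration of the selection logic above. This is a hedged sketch: it
assumes a populated TorchSDPAMetadata instance named meta, and the AttentionType
import path may differ across vLLM versions:

from vllm.attention.backends.abstract import AttentionType

# For encoder/decoder cross-attention, query lengths follow the decoder
# (meta.seq_lens) while key/value lengths follow the encoder
# (meta.encoder_seq_lens).
seq_lens_q, seq_lens_kv = meta.get_seq_lens(AttentionType.ENCODER_DECODER)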

set_attn_bias

set_attn_bias(
    attn_bias: List[Tensor], attn_type: str
) -> None

Update appropriate attention bias field of attention metadata, according to attention type.

Arguments:

  • attn_bias: The desired attention bias value
  • attn_type: encoder attention, decoder self-attention, or encoder/decoder cross-attention
Source code in vllm/attention/backends/torch_sdpa.py
def set_attn_bias(
    self,
    attn_bias: List[torch.Tensor],
    attn_type: str,
) -> None:
    '''
    Update appropriate attention bias field of attention metadata,
    according to attention type.

    Arguments:

    * attn_bias: The desired attention bias value
    * attn_type: encoder attention, decoder self-attention, or
                 encoder/decoder cross-attention
    '''

    if (attn_type == AttentionType.DECODER
            or attn_type == AttentionType.ENCODER_ONLY):
        self.attn_bias = attn_bias
    elif attn_type == AttentionType.ENCODER:
        self.encoder_attn_bias = attn_bias
    elif attn_type == AttentionType.ENCODER_DECODER:
        self.cross_attn_bias = attn_bias
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")

TorchSDPAMetadataBuilder

Bases: AttentionMetadataBuilder[TorchSDPAMetadata]

Source code in vllm/attention/backends/torch_sdpa.py
class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        self.chunked_prefill = input_builder.chunked_prefill
        self.input_builder = input_builder

    def prepare(self):
        self.input_data = self.input_builder.input_data

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
        input_data = self.input_data
        prefill_seq_lens = seq_lens[0:input_data.num_prefills]
        prefill_query_lens = query_lens[0:input_data.num_prefills]
        slot_mapping = torch.tensor(input_data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # For chunked-prefill
        if self.chunked_prefill and input_data.num_prefill_tokens != 0:
            prefill_block_tables = make_tensor_with_pad(
                self.input_data.prefill_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
            query_lens_tensor = torch.tensor(prefill_query_lens,
                                             dtype=torch.int32,
                                             device="cpu")
            kv_lens_tensor = torch.tensor(prefill_seq_lens,
                                          dtype=torch.int32,
                                          device="cpu")
            query_start_loc = torch.zeros(input_data.num_prefills + 1,
                                          dtype=torch.int32,
                                          device="cpu")
            kv_start_loc = torch.zeros(input_data.num_prefills + 1,
                                       dtype=torch.int32,
                                       device="cpu")
            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=query_start_loc[1:])
            torch.cumsum(kv_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=kv_start_loc[1:])
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)
        else:
            prefill_block_tables = None
            query_start_loc = None
            kv_start_loc = None
            max_query_len = None
            max_kv_len = None

        # For paged attention
        if input_data.num_decode_tokens != 0:
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[input_data.num_prefills:],
                dtype=torch.int32,
                device="cpu",
            )
            block_tables = make_tensor_with_pad(
                self.input_data.decode_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
        else:
            block_tables = torch.tensor([])
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[:input_data.num_prefills],
                dtype=torch.int32,
                device="cpu",
            )

        # For multi-modal models
        placeholder_index_maps = None
        if len(input_data.multi_modal_inputs_list) != 0:
            placeholder_index_maps = {
                modality: placeholder_map.index_map()
                for modality, placeholder_map in
                input_data.multi_modal_placeholder_maps.items()
            }

        attn_metadata = TorchSDPAMetadata(
            chunked_prefill=self.chunked_prefill,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            prefill_query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            max_decode_seq_len=input_data.max_decode_seq_len,
            num_prefills=input_data.num_prefills,
            num_prefill_tokens=input_data.num_prefill_tokens,
            num_decode_tokens=input_data.num_decode_tokens,
            block_tables=block_tables,
            prefill_block_tables=prefill_block_tables,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
        )

        return attn_metadata
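
The start-location tensors built above are exclusive prefix sums of the
per-sequence lengths. A standalone sketch with illustrative lengths:

import torch

prefill_query_lens = [5, 3, 7]
query_start_loc = torch.zeros(len(prefill_query_lens) + 1, dtype=torch.int32)
torch.cumsum(torch.tensor(prefill_query_lens, dtype=torch.int32),
             dim=0, out=query_start_loc[1:])
# Sequence i's tokens occupy the flat range
# [query_start_loc[i], query_start_loc[i + 1]).
assert query_start_loc.tolist() == [0, 5, 8, 15]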

chunked_prefill instance-attribute

chunked_prefill = chunked_prefill

input_builder instance-attribute

input_builder = input_builder

__init__

__init__(input_builder: ModelInputForCPUBuilder) -> None
Source code in vllm/attention/backends/torch_sdpa.py
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
    self.chunked_prefill = input_builder.chunked_prefill
    self.input_builder = input_builder

build

build(
    seq_lens: List[int],
    query_lens: List[int],
    cuda_graph_pad_size: int,
    batch_size: int,
) -> TorchSDPAMetadata
Source code in vllm/attention/backends/torch_sdpa.py
def build(self, seq_lens: List[int], query_lens: List[int],
          cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
    input_data = self.input_data
    prefill_seq_lens = seq_lens[0:input_data.num_prefills]
    prefill_query_lens = query_lens[0:input_data.num_prefills]
    slot_mapping = torch.tensor(input_data.slot_mapping,
                                dtype=torch.long,
                                device="cpu")

    # For chunked-prefill
    if self.chunked_prefill and input_data.num_prefill_tokens != 0:
        prefill_block_tables = make_tensor_with_pad(
            self.input_data.prefill_block_tables,
            pad=0,
            dtype=torch.int32,
            device="cpu",
        )
        query_lens_tensor = torch.tensor(prefill_query_lens,
                                         dtype=torch.int32,
                                         device="cpu")
        kv_lens_tensor = torch.tensor(prefill_seq_lens,
                                      dtype=torch.int32,
                                      device="cpu")
        query_start_loc = torch.zeros(input_data.num_prefills + 1,
                                      dtype=torch.int32,
                                      device="cpu")
        kv_start_loc = torch.zeros(input_data.num_prefills + 1,
                                   dtype=torch.int32,
                                   device="cpu")
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=torch.int32,
                     out=query_start_loc[1:])
        torch.cumsum(kv_lens_tensor,
                     dim=0,
                     dtype=torch.int32,
                     out=kv_start_loc[1:])
        max_query_len = max(prefill_query_lens)
        max_kv_len = max(prefill_seq_lens)
    else:
        prefill_block_tables = None
        query_start_loc = None
        kv_start_loc = None
        max_query_len = None
        max_kv_len = None

    # For paged attention
    if input_data.num_decode_tokens != 0:
        seq_lens_tensor = torch.tensor(
            input_data.seq_lens[input_data.num_prefills:],
            dtype=torch.int32,
            device="cpu",
        )
        block_tables = make_tensor_with_pad(
            self.input_data.decode_block_tables,
            pad=0,
            dtype=torch.int32,
            device="cpu",
        )
    else:
        block_tables = torch.tensor([])
        seq_lens_tensor = torch.tensor(
            input_data.seq_lens[:input_data.num_prefills],
            dtype=torch.int32,
            device="cpu",
        )

    # For multi-modal models
    placeholder_index_maps = None
    if len(input_data.multi_modal_inputs_list) != 0:
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            input_data.multi_modal_placeholder_maps.items()
        }

    attn_metadata = TorchSDPAMetadata(
        chunked_prefill=self.chunked_prefill,
        seq_lens=prefill_seq_lens,
        seq_lens_tensor=seq_lens_tensor,
        max_query_len=max_query_len,
        max_kv_len=max_kv_len,
        prefill_query_start_loc=query_start_loc,
        kv_start_loc=kv_start_loc,
        max_decode_seq_len=input_data.max_decode_seq_len,
        num_prefills=input_data.num_prefills,
        num_prefill_tokens=input_data.num_prefill_tokens,
        num_decode_tokens=input_data.num_decode_tokens,
        block_tables=block_tables,
        prefill_block_tables=prefill_block_tables,
        slot_mapping=slot_mapping,
        multi_modal_placeholder_index_maps=placeholder_index_maps,
        enable_kv_scales_calculation=False,
    )

    return attn_metadata

prepare

prepare()
Source code in vllm/attention/backends/torch_sdpa.py
def prepare(self):
    self.input_data = self.input_builder.input_data

_make_alibi_bias

_make_alibi_bias(
    alibi_slopes: Tensor, dtype: dtype, seq_lens: List[int]
) -> List[Tensor]
Source code in vllm/attention/backends/torch_sdpa.py
def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[torch.Tensor]:
    attn_biases: List[torch.Tensor] = []
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        bias = bias[None, :] - bias[:, None]

        num_heads = alibi_slopes.shape[0]
        bias = bias[None, :].repeat((num_heads, 1, 1))
        bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
        inf_mask = torch.empty(
            (1, seq_len, seq_len),
            dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
        attn_biases.append((bias + inf_mask).to(dtype))

    return attn_biases
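
A usage sketch for the helper above (assumes _make_alibi_bias is in scope;
values are illustrative):

import torch

alibi_slopes = torch.tensor([0.5, 0.25])  # one slope per head
biases = _make_alibi_bias(alibi_slopes, torch.float32, seq_lens=[3])
# One bias per sequence, shaped [1, num_heads, seq_len, seq_len], with the
# strict upper triangle at -inf to enforce causality.
assert biases[0].shape == (1, 2, 3, 3)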

_make_sliding_window_bias

_make_sliding_window_bias(
    seq_lens: List[int],
    window_size: Optional[int],
    dtype: dtype,
) -> List[Tensor]
Source code in vllm/attention/backends/torch_sdpa.py
def _make_sliding_window_bias(
    seq_lens: List[int],
    window_size: Optional[int],
    dtype: torch.dtype,
) -> List[torch.Tensor]:
    attn_biases: List[torch.Tensor] = []
    for seq_len in seq_lens:
        tensor = torch.full(
            (1, seq_len, seq_len),
            dtype=dtype,
            fill_value=1,
        )
        shift = 0
        mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore
        if window_size is not None:
            mask = torch.triu(mask, diagonal=shift - window_size + 1)
        mask = torch.log(mask)
        attn_biases.append(mask.to(dtype))

    return attn_biases
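
The log-of-0/1 trick above yields an additive bias: log(1) = 0 keeps a position
and log(0) = -inf masks it. A usage sketch (assumes _make_sliding_window_bias is
in scope; values are illustrative):

import torch

mask = _make_sliding_window_bias(seq_lens=[4], window_size=2,
                                 dtype=torch.float32)[0]
# Row i attends only to columns in (i - window_size, i]; every other position
# is -inf.
print(mask)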