Skip to content

vllm.attention.backends.hpu_attn

logger module-attribute

# Module-level logger; HPUAttentionImpl.__init__ uses it to emit a one-time
# warning (warning_once) when irope is requested.
logger = init_logger(__name__)

HPUAttentionBackend

Bases: AttentionBackend

Source code in vllm/attention/backends/hpu_attn.py
class HPUAttentionBackend(AttentionBackend):
    """HPU attention backend.

    Exposes the HPU implementation, metadata and state classes, and forwards
    KV-cache shape queries and block swap/copy requests to
    ``HPUPagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return this backend's name string."""
        backend_name = "HPU_ATTN"
        return backend_name

    @staticmethod
    def get_impl_cls() -> Type["HPUAttentionImpl"]:
        """Return the class that performs the attention computation."""
        impl = HPUAttentionImpl
        return impl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the metadata dataclass consumed by the implementation."""
        metadata = HPUAttentionMetadata
        return metadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention-state class (the shared common state)."""
        state = CommonAttentionState
        return state

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape computed by ``HPUPagedAttention``."""
        shape = HPUPagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dsts: torch.Tensor,
    ) -> None:
        """Forward a block swap request to ``HPUPagedAttention``."""
        mover = HPUPagedAttention.swap_blocks
        mover(src_kv_cache, dst_kv_cache, src_to_dsts)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dsts: torch.Tensor,
    ) -> None:
        """Forward a block copy request to ``HPUPagedAttention``."""
        copier = HPUPagedAttention.copy_blocks
        copier(kv_caches, src_to_dsts)

copy_blocks staticmethod

copy_blocks(
    kv_caches: List[Tensor], src_to_dsts: Tensor
) -> None
Source code in vllm/attention/backends/hpu_attn.py
@staticmethod
def copy_blocks(
    kv_caches: List[torch.Tensor],
    src_to_dsts: torch.Tensor,
) -> None:
    """Forward a block copy request to ``HPUPagedAttention.copy_blocks``.

    Both arguments are passed through unchanged and in the same order.
    """
    copier = HPUPagedAttention.copy_blocks
    copier(kv_caches, src_to_dsts)

get_impl_cls staticmethod

get_impl_cls() -> Type[HPUAttentionImpl]
Source code in vllm/attention/backends/hpu_attn.py
@staticmethod
def get_impl_cls() -> Type["HPUAttentionImpl"]:
    """Return the class that performs the HPU attention computation."""
    impl = HPUAttentionImpl
    return impl

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> Tuple[int, ...]
Source code in vllm/attention/backends/hpu_attn.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
) -> Tuple[int, ...]:
    """Return the KV-cache shape for the given cache geometry.

    The shape itself is computed by ``HPUPagedAttention.get_kv_cache_shape``;
    the arguments are forwarded in the same order.
    """
    shape = HPUPagedAttention.get_kv_cache_shape(
        num_blocks, block_size, num_kv_heads, head_size)
    return shape

get_metadata_cls staticmethod

get_metadata_cls() -> Type[AttentionMetadata]
Source code in vllm/attention/backends/hpu_attn.py
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
    """Return the metadata dataclass used by the HPU backend."""
    metadata = HPUAttentionMetadata
    return metadata

get_name staticmethod

get_name() -> str
Source code in vllm/attention/backends/hpu_attn.py
@staticmethod
def get_name() -> str:
    """Return the backend's name string."""
    backend_name = "HPU_ATTN"
    return backend_name

get_state_cls staticmethod

get_state_cls() -> Type[CommonAttentionState]
Source code in vllm/attention/backends/hpu_attn.py
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
    """Return the attention-state class (the shared common state)."""
    state = CommonAttentionState
    return state

swap_blocks staticmethod

swap_blocks(
    src_kv_cache: Tensor,
    dst_kv_cache: Tensor,
    src_to_dsts: Tensor,
) -> None
Source code in vllm/attention/backends/hpu_attn.py
@staticmethod
def swap_blocks(
    src_kv_cache: torch.Tensor,
    dst_kv_cache: torch.Tensor,
    src_to_dsts: torch.Tensor,
) -> None:
    """Forward a block swap request to ``HPUPagedAttention.swap_blocks``.

    All three arguments are passed through unchanged and in the same order.
    """
    mover = HPUPagedAttention.swap_blocks
    mover(src_kv_cache, dst_kv_cache, src_to_dsts)

HPUAttentionImpl

Bases: AttentionImpl, Module

If the input tensors contain prompt tokens, the layout is as follows:

    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

Otherwise, the layout is as follows:

    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding.

The prompts might have different lengths, while the generation tokens always have length 1.

Source code in vllm/attention/backends/hpu_attn.py
class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
    """HPU attention: prompt runs use ``ops.prompt_attention``, decode runs
    use ``HPUPagedAttention.forward_decode``.

    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    Note: ``forward`` receives ``query``/``key``/``value`` already batched as
    3-D tensors ``(batch_size, seq_len, hidden)``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int],
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        max_seq_len: int = 4096,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        """Build the HPU attention ops and validate the configuration.

        Args:
            num_heads: Number of query heads.
            head_size: Per-head size; must be in
                ``HPUPagedAttention.get_supported_head_sizes()``.
            scale: Softmax scale; stored as ``float``.
            num_kv_heads: Number of key/value heads; ``None`` means
                ``num_heads``.
            alibi_slopes: Optional per-head ALiBi slopes, stored as a
                bfloat16 tensor. Not allowed with the ``fsdpa`` prefill path.
            sliding_window: Stored on the instance; ``forward`` does not
                read it.
            kv_cache_dtype: KV-cache dtype string; quantized caches are
                rejected.
            blocksparse_params: Accepted for interface compatibility; unused.
            max_seq_len: Accepted for interface compatibility; unused.
            attn_type: Only ``AttentionType.DECODER`` is supported.
            kv_sharing_target_layer_name: Must be ``None``.
            use_irope: Only triggers a warning; irope is not implemented.

        Raises:
            NotImplementedError: KV sharing requested, a non-decoder
                ``attn_type``, or a quantized KV cache.
            ValueError: ``head_size`` is not supported.
            AssertionError: ``fsdpa`` prefill enabled together with ALiBi.
        """
        # Start the MRO lookup *after* AttentionImpl so its __init__ is
        # skipped and the later base (torch.nn.Module) is initialized.
        super(AttentionImpl, self).__init__()
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if use_irope:
            logger.warning_once(
                "Using irope in HPU is not supported yet, it will fall back "
                "to global attention for long context.")
        self.kv_cache_dtype = kv_cache_dtype
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.matmul_qk = Matmul()
        self.softmax = Softmax()
        self.matmul_av = Matmul()
        self.batch2block_matmul = Matmul()
        self.block2batch_matmul = Matmul()
        self.k_cache = VLLMKVCache()
        self.v_cache = VLLMKVCache()
        self.fused_scaled_dot_product_attention = kernels.fsdpa()

        # Select the prefill implementation. Later checks win, so 'fsdpa'
        # takes precedence over 'flex', which overrides the 'naive' default.
        flags = enabled_flags()
        self.prefill_impl = 'naive'
        if "flex_attention" in flags:
            self.prefill_impl = 'flex'
        if "fsdpa" in flags:
            # This is the only place prefill_impl becomes 'fsdpa', so this
            # single check covers the FusedSDPA/ALiBi incompatibility.
            assert alibi_slopes is None, \
                'Prefill with FusedSDPA not supported with alibi slopes!'
            self.prefill_impl = 'fsdpa'

        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        # NOTE: created without an explicit device; forward() moves the
        # derived position bias to attn_bias's device before using it.
        self.alibi_slopes = None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.bfloat16)
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.attn_type = attn_type
        if self.attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "HPUAttentionImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "HPUAttention with FP8 KV cache not yet supported")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: HPUAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with HPU prompt attention and PagedAttention decode.

        Args:
            layer: Calling attention layer; not used here.
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len_kv, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len_kv, num_kv_heads * head_size]
            kv_cache: Expected to be a tuple accepted by
                ``HPUPagedAttention.split_kv_cache``. Anything else (e.g.
                ``None`` during the initial memory-profiling run) skips the
                cache write.
            attn_metadata: Metadata for attention.
            output: Ignored; a new tensor is returned.
            output_scale: Must be ``None``.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        Raises:
            NotImplementedError: If ``output_scale`` is given.
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for HPUAttentionImpl")

        batch_size, seq_len, hidden_size = query.shape
        _, seq_len_kv, _ = key.shape

        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        block_indices = attn_metadata.block_indices
        block_offsets = attn_metadata.block_offsets
        key_cache = None
        value_cache = None
        if attn_metadata.is_prompt and \
                self.attn_type != AttentionType.ENCODER_ONLY:
            # Split the flat token dim into (block_indices.size(0), -1) so
            # it lines up with block_indices for the cache write below.
            key = key.unflatten(0, (block_indices.size(0), -1))
            value = value.unflatten(0, (block_indices.size(0), -1))
        if isinstance(kv_cache, tuple):
            key_cache, value_cache = HPUPagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
            # Store the new keys/values in the cache at
            # (block_indices, block_offsets).
            key_cache = self.k_cache(key, key_cache, block_indices,
                                     block_offsets)
            value_cache = self.v_cache(value, value_cache, block_indices,
                                       block_offsets)

        if attn_metadata.is_prompt:
            # Prompt run.
            query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
            kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
                        self.head_size)

            attn_bias = attn_metadata.attn_bias
            if attn_bias is not None and self.alibi_slopes is not None:
                position_bias = _make_alibi_bias(self.alibi_slopes,
                                                 self.num_kv_heads,
                                                 attn_bias.dtype,
                                                 attn_bias.shape[-1])
                # alibi_slopes is created in __init__ without a device, so
                # align the bias with attn_bias before the in-place add.
                position_bias = position_bias.to(attn_bias.device)
                # NOTE(review): when num_heads != num_kv_heads,
                # _make_alibi_bias returns a 5-D tensor, which cannot be
                # added in place to the 4-D tiled attn_bias — confirm
                # whether ALiBi with GQA is meant to be supported here.
                attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
                attn_bias.add_(position_bias)

            out = ops.prompt_attention(
                impl=self.prefill_impl,
                query=query.view(query_shape),
                key=key.view(kv_shape),
                value=value.view(kv_shape),
                is_causal=True,
                attn_bias=attn_bias,
                valid_seq_lengths=attn_metadata.seq_lens_tensor,
                **self.common_attention_args(attn_metadata.block_list,
                                             key_cache, value_cache))
            output = out.reshape(batch_size, seq_len, hidden_size)
        else:
            # Decoding run.
            output = HPUPagedAttention.forward_decode(
                query=query,
                block_mapping=attn_metadata.block_mapping,
                block_bias=attn_metadata.attn_bias,
                block_groups=attn_metadata.block_groups,
                **self.common_attention_args(attn_metadata.block_list,
                                             key_cache, value_cache))
        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)

    def common_attention_args(self,
                              block_list=None,
                              key_cache=None,
                              value_cache=None):
        """Return the keyword arguments shared by prompt and decode paths.

        The dict holds the scale, the matmul/softmax ops, the fused SDPA
        ``apply`` callable (``None`` when no fused kernel is set), the K/V
        cache fetch functions, and the three arguments passed in.
        """
        fsdpa_op = self.fused_scaled_dot_product_attention.apply \
            if self.fused_scaled_dot_product_attention is not None else None
        return {
            'scale': self.scale,
            'matmul_qk_op': self.matmul_qk,
            'matmul_av_op': self.matmul_av,
            'batch2block_matmul_op': self.batch2block_matmul,
            'block2batch_matmul_op': self.block2batch_matmul,
            'fsdpa_op': fsdpa_op,
            'keys_fetch_func': self.k_cache.fetch_from_cache,
            'values_fetch_func': self.v_cache.fetch_from_cache,
            'softmax_op': self.softmax,
            'block_list': block_list,
            'key_cache': key_cache,
            'value_cache': value_cache,
        }

alibi_slopes instance-attribute

alibi_slopes = alibi_slopes

attn_type instance-attribute

attn_type = attn_type

batch2block_matmul instance-attribute

batch2block_matmul = Matmul()

block2batch_matmul instance-attribute

block2batch_matmul = Matmul()

fused_scaled_dot_product_attention instance-attribute

fused_scaled_dot_product_attention = fsdpa()

head_size instance-attribute

head_size = head_size

k_cache instance-attribute

k_cache = VLLMKVCache()

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

matmul_av instance-attribute

matmul_av = Matmul()

matmul_qk instance-attribute

matmul_qk = Matmul()

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = (
    num_heads if num_kv_heads is None else num_kv_heads
)

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

prefill_impl instance-attribute

prefill_impl = 'naive'

scale instance-attribute

scale = float(scale)

sliding_window instance-attribute

sliding_window = sliding_window

softmax instance-attribute

softmax = Softmax()

v_cache instance-attribute

v_cache = VLLMKVCache()

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[List[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[Dict[str, Any]] = None,
    max_seq_len: int = 4096,
    attn_type: str = DECODER,
    kv_sharing_target_layer_name: Optional[str] = None,
    use_irope: bool = False,
) -> None
Source code in vllm/attention/backends/hpu_attn.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: Optional[int],
    alibi_slopes: Optional[List[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    blocksparse_params: Optional[Dict[str, Any]] = None,
    max_seq_len: int = 4096,
    attn_type: str = AttentionType.DECODER,
    kv_sharing_target_layer_name: Optional[str] = None,
    use_irope: bool = False,
) -> None:
    """Build the HPU attention ops and validate the configuration.

    Args:
        num_heads: Number of query heads.
        head_size: Per-head size; must be in
            ``HPUPagedAttention.get_supported_head_sizes()``.
        scale: Softmax scale; stored as ``float``.
        num_kv_heads: Number of key/value heads; ``None`` means
            ``num_heads``.
        alibi_slopes: Optional per-head ALiBi slopes, stored as a bfloat16
            tensor. Not allowed with the ``fsdpa`` prefill path.
        sliding_window: Stored on the instance; ``forward`` does not read it.
        kv_cache_dtype: KV-cache dtype string; quantized caches are rejected.
        blocksparse_params: Accepted for interface compatibility; unused.
        max_seq_len: Accepted for interface compatibility; unused.
        attn_type: Only ``AttentionType.DECODER`` is supported.
        kv_sharing_target_layer_name: Must be ``None``.
        use_irope: Only triggers a warning; irope is not implemented.

    Raises:
        NotImplementedError: KV sharing requested, a non-decoder
            ``attn_type``, or a quantized KV cache.
        ValueError: ``head_size`` is not supported.
        AssertionError: ``fsdpa`` prefill enabled together with ALiBi.
    """
    # Start the MRO lookup *after* AttentionImpl so its __init__ is skipped
    # and the later base (torch.nn.Module) is initialized.
    super(AttentionImpl, self).__init__()
    if kv_sharing_target_layer_name is not None:
        raise NotImplementedError("KV sharing is not supported in V0.")
    if use_irope:
        logger.warning_once(
            "Using irope in HPU is not supported yet, it will fall back "
            "to global attention for long context.")
    self.kv_cache_dtype = kv_cache_dtype
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.matmul_qk = Matmul()
    self.softmax = Softmax()
    self.matmul_av = Matmul()
    self.batch2block_matmul = Matmul()
    self.block2batch_matmul = Matmul()
    self.k_cache = VLLMKVCache()
    self.v_cache = VLLMKVCache()
    self.fused_scaled_dot_product_attention = kernels.fsdpa()

    # Select the prefill implementation. Later checks win, so 'fsdpa' takes
    # precedence over 'flex', which overrides the 'naive' default.
    flags = enabled_flags()
    self.prefill_impl = 'naive'
    if "flex_attention" in flags:
        self.prefill_impl = 'flex'
    if "fsdpa" in flags:
        # This is the only place prefill_impl becomes 'fsdpa', so this
        # single check covers the FusedSDPA/ALiBi incompatibility.
        assert alibi_slopes is None, \
            'Prefill with FusedSDPA not supported with alibi slopes!'
        self.prefill_impl = 'fsdpa'

    self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
    self.sliding_window = sliding_window
    # NOTE: created without an explicit device.
    self.alibi_slopes = None if alibi_slopes is None else torch.tensor(
        alibi_slopes, dtype=torch.bfloat16)
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
    if head_size not in supported_head_sizes:
        raise ValueError(
            f"Head size {head_size} is not supported by PagedAttention. "
            f"Supported head sizes are: {supported_head_sizes}.")

    self.attn_type = attn_type
    if self.attn_type != AttentionType.DECODER:
        raise NotImplementedError("Encoder self-attention and "
                                  "encoder/decoder cross-attention "
                                  "are not implemented for "
                                  "HPUAttentionImpl")

    if is_quantized_kv_cache(self.kv_cache_dtype):
        raise NotImplementedError(
            "HPUAttention with FP8 KV cache not yet supported")

common_attention_args

common_attention_args(
    block_list=None, key_cache=None, value_cache=None
)
Source code in vllm/attention/backends/hpu_attn.py
def common_attention_args(self,
                          block_list=None,
                          key_cache=None,
                          value_cache=None):
    """Collect the keyword arguments shared by prompt and decode attention.

    Args:
        block_list: Forwarded unchanged under ``'block_list'``.
        key_cache: Forwarded unchanged under ``'key_cache'``.
        value_cache: Forwarded unchanged under ``'value_cache'``.

    Returns:
        A dict with the instance's scale, matmul/softmax ops, the fused SDPA
        ``apply`` callable (``None`` when no fused kernel is set), the K/V
        cache fetch functions, and the three arguments above.
    """
    fused_kernel = self.fused_scaled_dot_product_attention
    if fused_kernel is None:
        fused_apply = None
    else:
        fused_apply = fused_kernel.apply

    return dict(
        scale=self.scale,
        matmul_qk_op=self.matmul_qk,
        matmul_av_op=self.matmul_av,
        batch2block_matmul_op=self.batch2block_matmul,
        block2batch_matmul_op=self.block2batch_matmul,
        fsdpa_op=fused_apply,
        keys_fetch_func=self.k_cache.fetch_from_cache,
        values_fetch_func=self.v_cache.fetch_from_cache,
        softmax_op=self.softmax,
        block_list=block_list,
        key_cache=key_cache,
        value_cache=value_cache,
    )

forward

forward(
    layer: AttentionLayer,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: HPUAttentionMetadata,
    output: Optional[Tensor] = None,
    output_scale: Optional[Tensor] = None,
) -> Tensor

Forward pass with HPU prompt attention (ops.prompt_attention) and PagedAttention decoding (HPUPagedAttention.forward_decode).

Parameters:

Name Type Description Default
query Tensor

shape = [batch_size, seq_len, num_heads * head_size]

required
key Tensor

shape = [batch_size, seq_len_kv, num_kv_heads * head_size]

required
value Tensor

shape = [batch_size, seq_len_kv, num_kv_heads * head_size]

required
attn_metadata HPUAttentionMetadata

Metadata for attention.

required

Returns: shape = [batch_size, seq_len, num_heads * head_size]

Source code in vllm/attention/backends/hpu_attn.py
def forward(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: HPUAttentionMetadata,
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward pass with HPU prompt attention and PagedAttention decode.

    Args:
        layer: Calling attention layer; not used here.
        query: shape = [batch_size, seq_len, num_heads * head_size]
        key: shape = [batch_size, seq_len_kv, num_kv_heads * head_size]
        value: shape = [batch_size, seq_len_kv, num_kv_heads * head_size]
        kv_cache: Expected to be a tuple accepted by
            ``HPUPagedAttention.split_kv_cache``. Anything else (e.g.
            ``None`` during the initial memory-profiling run) skips the
            cache write.
        attn_metadata: Metadata for attention.
        output: Ignored; a new tensor is returned.
        output_scale: Must be ``None``.
    Returns:
        shape = [batch_size, seq_len, num_heads * head_size]
    Raises:
        NotImplementedError: If ``output_scale`` is given.
    """
    if output_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported"
            " for HPUAttentionImpl")

    batch_size, seq_len, hidden_size = query.shape
    _, seq_len_kv, _ = key.shape

    key = key.view(-1, self.num_kv_heads, self.head_size)
    value = value.view(-1, self.num_kv_heads, self.head_size)
    block_indices = attn_metadata.block_indices
    block_offsets = attn_metadata.block_offsets
    key_cache = None
    value_cache = None
    if attn_metadata.is_prompt and \
            self.attn_type != AttentionType.ENCODER_ONLY:
        # Split the flat token dim into (block_indices.size(0), -1) so it
        # lines up with block_indices for the cache write below.
        key = key.unflatten(0, (block_indices.size(0), -1))
        value = value.unflatten(0, (block_indices.size(0), -1))
    if isinstance(kv_cache, tuple):
        key_cache, value_cache = HPUPagedAttention.split_kv_cache(
            kv_cache, self.num_kv_heads, self.head_size)
        # Store the new keys/values in the cache at
        # (block_indices, block_offsets).
        key_cache = self.k_cache(key, key_cache, block_indices,
                                 block_offsets)
        value_cache = self.v_cache(value, value_cache, block_indices,
                                   block_offsets)

    if attn_metadata.is_prompt:
        # Prompt run.
        query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
        kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
                    self.head_size)

        attn_bias = attn_metadata.attn_bias
        if attn_bias is not None and self.alibi_slopes is not None:
            position_bias = _make_alibi_bias(self.alibi_slopes,
                                             self.num_kv_heads,
                                             attn_bias.dtype,
                                             attn_bias.shape[-1])
            # alibi_slopes is created in __init__ without a device, so align
            # the bias with attn_bias before the in-place add.
            position_bias = position_bias.to(attn_bias.device)
            # NOTE(review): when num_heads != num_kv_heads, _make_alibi_bias
            # returns a 5-D tensor, which cannot be added in place to the
            # 4-D tiled attn_bias — confirm whether ALiBi with GQA is meant
            # to be supported here.
            attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
            attn_bias.add_(position_bias)

        out = ops.prompt_attention(
            impl=self.prefill_impl,
            query=query.view(query_shape),
            key=key.view(kv_shape),
            value=value.view(kv_shape),
            is_causal=True,
            attn_bias=attn_bias,
            valid_seq_lengths=attn_metadata.seq_lens_tensor,
            **self.common_attention_args(attn_metadata.block_list,
                                         key_cache, value_cache))
        output = out.reshape(batch_size, seq_len, hidden_size)
    else:
        # Decoding run.
        output = HPUPagedAttention.forward_decode(
            query=query,
            block_mapping=attn_metadata.block_mapping,
            block_bias=attn_metadata.attn_bias,
            block_groups=attn_metadata.block_groups,
            **self.common_attention_args(attn_metadata.block_list,
                                         key_cache, value_cache))
    # Reshape the output tensor.
    return output.view(batch_size, seq_len, hidden_size)

HPUAttentionMetadata dataclass

Bases: HPUPagedAttentionMetadata, AttentionMetadata

Metadata for HPUAttentionBackend.

Source code in vllm/attention/backends/hpu_attn.py
@dataclass
class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
    """Metadata for HPUAttentionBackend.

    Adds the fields below on top of the inherited ones (the block_list,
    block_mapping, block_usage, block_indices, block_offsets and
    block_groups tensors that HPUAttentionImpl.forward reads).
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # Prompt runs: passed to ops.prompt_attention as attn_bias (with ALiBi
    # added on top when slopes are configured). Decode runs: passed to
    # HPUPagedAttention.forward_decode as block_bias.
    attn_bias: Optional[torch.Tensor]
    # Passed to ops.prompt_attention as valid_seq_lengths on prompt runs.
    seq_lens_tensor: Optional[torch.Tensor]
    # Not read by HPUAttentionImpl.forward; presumably per-sequence context
    # lengths — TODO confirm against the metadata builder.
    context_lens_tensor: Optional[torch.Tensor]

attn_bias instance-attribute

attn_bias: Optional[Tensor]

context_lens_tensor instance-attribute

context_lens_tensor: Optional[Tensor]

is_prompt instance-attribute

is_prompt: bool

seq_lens_tensor instance-attribute

seq_lens_tensor: Optional[Tensor]

__init__

__init__(
    num_prefills: int,
    num_prefill_tokens: int,
    num_decode_tokens: int,
    slot_mapping: Tensor,
    multi_modal_placeholder_index_maps: Optional[
        Dict[str, IndexMap]
    ],
    enable_kv_scales_calculation: bool,
    block_list: Optional[Tensor],
    block_mapping: Optional[Tensor],
    block_usage: Optional[Tensor],
    block_indices: Optional[Tensor],
    block_offsets: Optional[Tensor],
    block_groups: Optional[Tensor],
    is_prompt: bool,
    attn_bias: Optional[Tensor],
    seq_lens_tensor: Optional[Tensor],
    context_lens_tensor: Optional[Tensor],
) -> None

_make_alibi_bias

_make_alibi_bias(
    alibi_slopes: Tensor,
    num_kv_heads: int,
    dtype: dtype,
    seq_len: int,
) -> Tensor
Source code in vllm/attention/backends/hpu_attn.py
def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_len: int,
) -> torch.Tensor:
    bias = torch.arange(seq_len, dtype=dtype)
    # NOTE(zhuohan): HF uses
    #     `bias = bias[None, :].repeat(seq_len, 1)`
    # here. We find that both biases give the same results, but
    # the bias below more accurately follows the original ALiBi
    # paper.
    # Calculate a matrix where each element represents ith element- jth
    # element.
    bias = bias[None, :] - bias[:, None]

    padded_len = (seq_len + 7) // 8 * 8
    num_heads = alibi_slopes.shape[0]
    bias = torch.empty(
        1,  # batch size
        num_heads,
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )[:, :, :, :seq_len].copy_(bias)
    bias.mul_(alibi_slopes[:, None, None])
    if num_heads != num_kv_heads:
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
    return bias