vllm.attention.ops.ipex_attn

PagedAttention module-attribute

PagedAttention = (
    _IPEXPagedAttention if _use_ipex else _PagedAttention
)
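
PagedAttention resolves to the IPEX-backed implementation when IPEX is available and to the generic _PagedAttention fallback otherwise; both expose the same static interface. The sketch below uses illustrative sizes and assumes this module imports in your environment; it exercises only the pure-shape helpers, since write_to_paged_cache and forward_decode additionally require the IPEX or vLLM native kernels.

import torch
from vllm.attention.ops.ipex_attn import PagedAttention

# Illustrative cache geometry (not tied to any real model).
num_blocks, block_size, num_kv_heads, head_size = 128, 16, 8, 64

shape = PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                          num_kv_heads, head_size)
kv_cache = torch.zeros(shape, dtype=torch.float16)

key_cache, value_cache = PagedAttention.split_kv_cache(kv_cache, num_kv_heads,
                                                       head_size)
print(key_cache.shape, value_cache.shape)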

_use_ipex module-attribute

_use_ipex = True
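
The value shown above is the one evaluated when this page was built. A flag like this is typically derived from an availability check on intel_extension_for_pytorch; the snippet below is a hedged sketch of such a check, not the module's literal source.

# Sketch only (assumption): gate the backend on whether
# intel_extension_for_pytorch can be found in the current environment.
from importlib.util import find_spec

_use_ipex = find_spec("intel_extension_for_pytorch") is not None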

_IPEXPagedAttention

Bases: _PagedAttention

Source code in vllm/attention/ops/ipex_attn.py
class _IPEXPagedAttention(_PagedAttention):

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        ipex_modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache,
            slot_mapping.flatten().int())

    @staticmethod
    def forward_decode(
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        block_size = value_cache.shape[2]
        head_mapping = torch.arange(
            0,
            num_kv_heads,
            device="cpu",
            dtype=torch.int32,
        ).view(num_kv_heads,
               1).repeat_interleave(query.size(1) // num_kv_heads).flatten()
        ipex_modules.PagedAttention.single_query_cached_kv_attention(
            output, query.contiguous(), key_cache, value_cache, head_mapping,
            scale, block_tables, context_lens, block_size, max_context_len,
            alibi_slopes)

forward_decode staticmethod

forward_decode(
    output: Tensor,
    query: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    block_tables: Tensor,
    context_lens: Tensor,
    max_context_len: int,
    kv_cache_dtype: str,
    num_kv_heads: int,
    scale: float,
    alibi_slopes: Optional[Tensor],
    k_scale: Tensor,
    v_scale: Tensor,
    *args,
) -> None
Source code in vllm/attention/ops/ipex_attn.py
@staticmethod
def forward_decode(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    max_context_len: int,
    kv_cache_dtype: str,
    num_kv_heads: int,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    *args,
) -> None:
    block_size = value_cache.shape[2]
    head_mapping = torch.arange(
        0,
        num_kv_heads,
        device="cpu",
        dtype=torch.int32,
    ).view(num_kv_heads,
           1).repeat_interleave(query.size(1) // num_kv_heads).flatten()
    ipex_modules.PagedAttention.single_query_cached_kv_attention(
        output, query.contiguous(), key_cache, value_cache, head_mapping,
        scale, block_tables, context_lens, block_size, max_context_len,
        alibi_slopes)
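
The head_mapping tensor assigns each query head to the KV head whose cached keys and values it reads (grouped-query attention). The standalone reproduction below uses illustrative sizes and assumes query is laid out as (num_seqs, num_query_heads, head_size), so query.size(1) is the query head count.

import torch

num_kv_heads = 8
num_query_heads = 32  # stands in for query.size(1)

head_mapping = torch.arange(
    0,
    num_kv_heads,
    device="cpu",
    dtype=torch.int32,
).view(num_kv_heads,
       1).repeat_interleave(num_query_heads // num_kv_heads).flatten()

# Each KV head index is repeated 4 (= 32 // 8) times:
# 0,0,0,0,1,1,1,1,...,7,7,7,7 as an int32 tensor of length 32.
print(head_mapping)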

split_kv_cache staticmethod

split_kv_cache(
    kv_cache: Tensor,
    num_kv_heads: int,
    head_size: int,
    *args,
) -> Tuple[Tensor, Tensor]
Source code in vllm/attention/ops/ipex_attn.py
@staticmethod
def split_kv_cache(
    kv_cache: torch.Tensor,
    num_kv_heads: int,
    head_size: int,
    *args,
) -> Tuple[torch.Tensor, torch.Tensor]:
    num_blocks = kv_cache.shape[1]

    key_cache = kv_cache[0]
    key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
    value_cache = kv_cache[1]
    value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size)
    return key_cache, value_cache
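
With the (2, num_blocks, block_size * num_kv_heads * head_size) layout from get_kv_cache_shape, the inferred dimension resolves to block_size, so both caches come out as (num_blocks, num_kv_heads, block_size, head_size). A shape-only check with illustrative sizes (assuming the module imports in your environment):

import torch
from vllm.attention.ops.ipex_attn import _IPEXPagedAttention

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 64
kv_cache = torch.zeros(2, num_blocks,
                       block_size * num_kv_heads * head_size)

key_cache, value_cache = _IPEXPagedAttention.split_kv_cache(
    kv_cache, num_kv_heads, head_size)
print(key_cache.shape)    # torch.Size([4, 2, 16, 64])
print(value_cache.shape)  # torch.Size([4, 2, 16, 64])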

write_to_paged_cache staticmethod

write_to_paged_cache(
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    slot_mapping: Tensor,
    kv_cache_dtype: str,
    k_scale: Tensor,
    v_scale: Tensor,
    *args,
) -> None
Source code in vllm/attention/ops/ipex_attn.py
@staticmethod
def write_to_paged_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    *args,
) -> None:
    ipex_modules.PagedAttention.reshape_and_cache(
        key, value, key_cache, value_cache,
        slot_mapping.flatten().int())
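
slot_mapping is flattened and cast to int32 before being handed to the IPEX kernel. In vLLM a slot index conventionally addresses one token's position in the paged cache as block_index * block_size + offset_within_block; the snippet below illustrates that convention (an assumption here, not shown on this page).

import torch

block_size = 16
block_index = torch.tensor([[3, 3], [7, 7]])      # per-token physical block
offset = torch.tensor([[0, 1], [4, 5]])           # per-token offset in block
slot_mapping = block_index * block_size + offset  # [[48, 49], [116, 117]]

# Mirrors the call above: flatten across the batch and cast for the kernel.
print(slot_mapping.flatten().int())  # tensor([48, 49, 116, 117], dtype=torch.int32)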

_PagedAttention

Source code in vllm/attention/ops/ipex_attn.py
class _PagedAttention:

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [32, 64, 80, 96, 112, 128, 192, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[int, ...]:
        return 2, num_blocks, block_size * num_kv_heads * head_size

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping.flatten(),
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    @staticmethod
    def forward_decode(
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        tp_rank: int = 0
        blocksparse_local_blocks: int = 0
        blocksparse_vert_stride: int = 0
        blocksparse_block_size: int = 64
        blocksparse_head_sliding_step: int = 0
        block_size = value_cache.shape[3]

        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
            tp_rank,
            blocksparse_local_blocks,
            blocksparse_vert_stride,
            blocksparse_block_size,
            blocksparse_head_sliding_step,
        )

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
        *args,
    ) -> None:
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)

copy_blocks staticmethod

copy_blocks(
    kv_caches: List[Tensor], src_to_dists: Tensor, *args
) -> None
Source code in vllm/attention/ops/ipex_attn.py
@staticmethod
def copy_blocks(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    *args,
) -> None:
    key_caches = [kv_cache[0] for kv_cache in kv_caches]
    value_caches = [kv_cache[1] for kv_cache in kv_caches]
    ops.copy_blocks(key_caches, value_caches, src_to_dists)
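
copy_blocks hands the per-layer key and value caches to the native op together with src_to_dists. The sketch below assumes the common vLLM convention of one [source_block, destination_block] row per copy; that format is an assumption here, not something this page documents.

import torch

# Hypothetical mapping (assumption): copy physical block 0 -> 5 and 1 -> 6
# in every layer's key and value cache.
src_to_dists = torch.tensor([[0, 5], [1, 6]], dtype=torch.int64)
# _PagedAttention.copy_blocks(kv_caches, src_to_dists) would then apply these
# copies across all entries of kv_caches.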

forward_decode staticmethod

forward_decode(
    output: Tensor,
    query: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    block_tables: Tensor,
    context_lens: Tensor,
    max_context_len: int,
    kv_cache_dtype: str,
    num_kv_heads: int,
    scale: float,
    alibi_slopes: Optional[Tensor],
    k_scale: Tensor,
    v_scale: Tensor,
    *args,
) -> None
Source code in vllm/attention/ops/ipex_attn.py
@staticmethod
def forward_decode(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    max_context_len: int,
    kv_cache_dtype: str,
    num_kv_heads: int,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    *args,
) -> None:
    tp_rank: int = 0
    blocksparse_local_blocks: int = 0
    blocksparse_vert_stride: int = 0
    blocksparse_block_size: int = 64
    blocksparse_head_sliding_step: int = 0
    block_size = value_cache.shape[3]

    ops.paged_attention_v1(
        output,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    *args,
) -> Tuple[int, ...]
Source code in vllm/attention/ops/ipex_attn.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    *args,
) -> Tuple[int, ...]:
    return 2, num_blocks, block_size * num_kv_heads * head_size
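
The leading 2 separates the key and value planes, and the last dimension flattens all slots of a block, so each plane stores num_blocks rows of block_size * num_kv_heads * head_size elements. With illustrative sizes:

from vllm.attention.ops.ipex_attn import _PagedAttention

# 1024 blocks of 16 slots, 8 KV heads, head size 128 (illustrative).
shape = _PagedAttention.get_kv_cache_shape(1024, 16, 8, 128)
print(shape)  # (2, 1024, 16384) since 16 * 8 * 128 = 16384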

get_supported_head_sizes staticmethod

get_supported_head_sizes() -> List[int]
Source code in vllm/attention/ops/ipex_attn.py
@staticmethod
def get_supported_head_sizes() -> List[int]:
    return [32, 64, 80, 96, 112, 128, 192, 256]

split_kv_cache staticmethod

split_kv_cache(
    kv_cache: Tensor,
    num_kv_heads: int,
    head_size: int,
    *args,
) -> Tuple[Tensor, Tensor]
Source code in vllm/attention/ops/ipex_attn.py
@staticmethod
def split_kv_cache(
    kv_cache: torch.Tensor,
    num_kv_heads: int,
    head_size: int,
    *args,
) -> Tuple[torch.Tensor, torch.Tensor]:
    x = 16 // kv_cache.element_size()
    num_blocks = kv_cache.shape[1]

    key_cache = kv_cache[0]
    key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                               -1, x)
    value_cache = kv_cache[1]
    value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
    return key_cache, value_cache
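
Here x = 16 // element_size is the number of elements that fit in 16 bytes (8 for fp16/bf16, 4 for fp32), and the inferred dimensions resolve to block_size. A shape-only check with illustrative fp16 sizes (assuming the module imports in your environment):

import torch
from vllm.attention.ops.ipex_attn import _PagedAttention

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 64
kv_cache = torch.zeros(2, num_blocks,
                       block_size * num_kv_heads * head_size,
                       dtype=torch.float16)  # element_size() == 2, so x == 8

key_cache, value_cache = _PagedAttention.split_kv_cache(
    kv_cache, num_kv_heads, head_size)
print(key_cache.shape)    # torch.Size([4, 2, 8, 16, 8]) -> (blocks, kv_heads, head_size // x, block_size, x)
print(value_cache.shape)  # torch.Size([4, 2, 64, 16])   -> (blocks, kv_heads, head_size, block_size)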

write_to_paged_cache staticmethod

write_to_paged_cache(
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    slot_mapping: Tensor,
    kv_cache_dtype: str,
    k_scale: Tensor,
    v_scale: Tensor,
    *args,
) -> None
Source code in vllm/attention/ops/ipex_attn.py
@staticmethod
def write_to_paged_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    *args,
) -> None:
    ops.reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping.flatten(),
        kv_cache_dtype,
        k_scale,
        v_scale,
    )