vllm.attention.ops.blocksparse_attention.interface

IS_COMPUTE_8_OR_ABOVE module-attribute

IS_COMPUTE_8_OR_ABOVE = has_device_capability(80)

LocalStridedBlockSparseAttn

Bases: Module

Source code in vllm/attention/ops/blocksparse_attention/interface.py
class LocalStridedBlockSparseAttn(torch.nn.Module):

    def __init__(
        self,
        n_heads,
        max_seqlen,
        local_blocks,
        vert_stride,
        block_size,
        device=None,
        dtype=None,
        homo_head=False,
        active_head_range=None,
        q_block_size=None,
        use_spda=None,
    ):
        super().__init__()
        if use_spda is None:
            use_spda = current_platform.is_rocm() or \
                        current_platform.is_cpu() or not \
                        IS_COMPUTE_8_OR_ABOVE
        device = device or (torch.cuda.current_device()
                            if current_platform.is_cuda_alike() else "cpu")
        device = torch.device(device)
        # NOTE: vLLM's CPU backend supports BF16 instead of FP16.
        dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
                          or device.type == "cpu" else torch.half)

        self.n_heads = n_heads
        self.max_seqlen = max_seqlen
        self.local_blocks = local_blocks
        self.vert_stride = vert_stride
        self.use_spda = use_spda
        self.dtype = dtype
        self.device = device
        self.block_size = block_size
        self.q_block_size = q_block_size
        self.homo_head = homo_head
        self.active_head_range = active_head_range
        self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride,
                                                       homo_head)

        sparse_layout, sparse_pattern, self.dense_attn_mask = (
            self.get_attn_pattern(dtype, device))

        if q_block_size is not None and q_block_size != block_size:
            if q_block_size > block_size:
                assert q_block_size % block_size == 0
                blocks_to_merge = q_block_size // block_size
                shape = sparse_pattern.shape
                sparse_pattern = sparse_pattern.view(shape[0], -1,
                                                     blocks_to_merge,
                                                     shape[-1])
                sparse_pattern = sparse_pattern.sum(2)
                sparse_layout = dense_to_crow_col(sparse_pattern)
            else:
                raise ValueError(
                    "Does not support smaller q_block_size. It will be slower."
                )

        self.sparse_layout = sparse_layout

    def get_attn_pattern(self, dtype, device):
        sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask(
            self.n_heads,
            self.max_seqlen,
            self.max_seqlen,
            dtype,
            device,
            block_size=self.block_size,
            local_blocks=self.local_blocks,
            vert_stride=self.vert_stride,
            homo_head=self.homo_head,
            return_dense=self.use_spda,
            dense_mask_type="bias",
        )
        if (not self.homo_head) and (self.active_head_range is not None):
            assert isinstance(self.active_head_range, tuple)
            assert (len(self.active_head_range) == 2)
            h_start, h_end = self.active_head_range
            sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout)
            if self.use_spda:
                dense_attn_mask = dense_attn_mask[h_start:h_end]
        return sparse_layout, sparse_pattern, dense_attn_mask

    def varlen_attn(self,
                    q,
                    k,
                    v,
                    cu_seqlens_k,
                    cu_seqlens_q=None,
                    sm_scale=None):
        """
        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
        Supports grouped attention, where `q[:, i*r:(i*r + r)]`
        corresponds to `k[:, i]`, with `r` being the q/k head ratio.
        cu_seqlens_k: shape=(batch_size + 1,),
        indicating the segment boundaries of the samples,
        e.g., `k[cu_seqlens_k[i]:cu_seqlens_k[i+1]]` is the key of sample i.
        cu_seqlens_q: shape=(batch_size + 1,).
        Defaults to None: same as cu_seqlens_k for prefilling, or
        [0, 1, .., batch_size] for decoding.
        It only needs to be specified when q is a mix of
        prefilling and decoding.
        sm_scale: softmax scale, defaults to 1/sqrt(head_size).

        return: tensor of the same shape as q.
        """
        assert (
            IS_COMPUTE_8_OR_ABOVE
        ), "Requires compute capability of 8 or above (Ampere or newer) to use \
            Triton kernel."

        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))

        return blocksparse_flash_attn_varlen_fwd(
            q,
            k,
            v,
            cu_seqlens_k,
            cu_seqlens_q,
            sm_scale,
            self.sparse_layout,
            block_size=self.block_size,
            q_block_size=self.q_block_size,
            max_seqlen=self.max_seqlen,
        )

    @staticmethod
    def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1):
        """
        :param x: (total_tokens, n_heads, head_size)
        :return: (batch, n_heads, length, head_size)
        """
        x_padded = x.new_empty(
            len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2))
        cu_seqlens = cu_seqlens.cpu()
        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
            x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0,
                                                             1).unsqueeze(1))
        return x_padded.flatten(1, 2)

    @staticmethod
    def transpose_and_unpad(x_padded, cu_seqlens):
        """
        :param x_padded: (batch, n_heads, length, head_size)
        :return: (total_tokens, n_heads, head_size)
        """
        cu_seqlens = cu_seqlens.cpu()
        total_n_tokens = cu_seqlens[-1]
        x = x_padded.new_empty(total_n_tokens, x_padded.size(1),
                               x_padded.size(3))
        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
            x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1))
        return x

    def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
        """For CPU, V100 or other older GPUs.
        NOTE: torch SPDA supports nested tensor,
        but seems extremely slow. Choose to pad instead.
        """
        assert (cu_seqlens_q is None or
                (cu_seqlens_q
                 == cu_seqlens_k).all()), "Can only handle prompt with SPDA."
        assert q.size(0) == k.size(0), "can only handle prompt with SPDA."

        assert q.size(1) % k.size(1) == 0
        q_k_ratio = q.size(1) // k.size(1)
        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
        cu_seqlens = cu_seqlens_k.cpu()
        maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

        if (self.dense_attn_mask.dtype != q.dtype
                or self.dense_attn_mask.device != q.device):
            _, _, self.dense_attn_mask = self.get_attn_pattern(
                q.dtype, q.device)
        attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen]

        q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1)
        k2, v2 = (self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
                  for x in [k, v])
        spda_output = torch.nn.functional.scaled_dot_product_attention(
            q2, k2, v2, attn_mask=attn_mask, scale=sm_scale)
        return self.transpose_and_unpad(spda_output, cu_seqlens)

    def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
        """Dispatch to `varlen_attn` (Ampere or newer) or
        `self.spda`(cpu, Volta, Turing or older)based on
        the type of device used and cuda compute capability.

        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
                Support grouped attention, with `q[:, i*r:(i*r + r)]`
                is correspondent to `k[:, i]`, where `r` is the q/k ratio.
        cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples,
                    e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
        cu_seqlens_q: shape=(batch_size + 1, ).
                    Default None: same as cu_seqlens_k for prefilling or
                    [0, 1, .., batch_size] for decoding.
                    The only case you need to specify
                    is when q is a mix of prefilling
                    and decoding.
        sm_scale: softmax scale, default to 1/sqrt(head_size).

        return: tensor of shape as q.
        """
        assert k.dim() == 3
        if self.use_spda:
            return self.spda(
                q,
                k,
                v,
                cu_seqlens_k,
                cu_seqlens_q=cu_seqlens_q,
                sm_scale=sm_scale,
            )
        return self.varlen_attn(q,
                                k,
                                v,
                                cu_seqlens_k,
                                cu_seqlens_q=cu_seqlens_q,
                                sm_scale=sm_scale)
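
The example below is a minimal usage sketch, not part of the vLLM source: it constructs the module with an illustrative block-sparse configuration and runs it over a packed, variable-length batch. All shapes, head counts, and sparsity parameters are invented for illustration.

import torch

from vllm.attention.ops.blocksparse_attention.interface import (
    LocalStridedBlockSparseAttn)

# Illustrative configuration: 8 heads, 64-token blocks, 4 local blocks,
# and every 8th block kept as a vertical stride.
attn = LocalStridedBlockSparseAttn(
    n_heads=8,
    max_seqlen=2048,
    local_blocks=4,
    vert_stride=8,
    block_size=64,
)

# Two samples of 5 and 7 tokens packed along the token dimension.
cu_seqlens_k = torch.tensor([0, 5, 12], dtype=torch.int32, device=attn.device)
q = torch.randn(12, 8, 128, dtype=attn.dtype, device=attn.device)
k = torch.randn(12, 8, 128, dtype=attn.dtype, device=attn.device)
v = torch.randn(12, 8, 128, dtype=attn.dtype, device=attn.device)

out = attn(q, k, v, cu_seqlens_k)  # same shape as q: (12, 8, 128)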

active_head_range instance-attribute

active_head_range = active_head_range

block_size instance-attribute

block_size = block_size

device instance-attribute

device = device

dtype instance-attribute

dtype = dtype

head_sliding_step instance-attribute

head_sliding_step = get_head_sliding_step(
    n_heads, vert_stride, homo_head
)

homo_head instance-attribute

homo_head = homo_head

local_blocks instance-attribute

local_blocks = local_blocks

max_seqlen instance-attribute

max_seqlen = max_seqlen

n_heads instance-attribute

n_heads = n_heads

q_block_size instance-attribute

q_block_size = q_block_size

sparse_layout instance-attribute

sparse_layout = sparse_layout

use_spda instance-attribute

use_spda = use_spda

vert_stride instance-attribute

vert_stride = vert_stride

__init__

__init__(
    n_heads,
    max_seqlen,
    local_blocks,
    vert_stride,
    block_size,
    device=None,
    dtype=None,
    homo_head=False,
    active_head_range=None,
    q_block_size=None,
    use_spda=None,
)
Source code in vllm/attention/ops/blocksparse_attention/interface.py
def __init__(
    self,
    n_heads,
    max_seqlen,
    local_blocks,
    vert_stride,
    block_size,
    device=None,
    dtype=None,
    homo_head=False,
    active_head_range=None,
    q_block_size=None,
    use_spda=None,
):
    super().__init__()
    if use_spda is None:
        use_spda = current_platform.is_rocm() or \
                    current_platform.is_cpu() or not \
                    IS_COMPUTE_8_OR_ABOVE
    device = device or (torch.cuda.current_device()
                        if current_platform.is_cuda_alike() else "cpu")
    device = torch.device(device)
    # NOTE: vLLM's CPU backend supports BF16 instead of FP16.
    dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
                      or device.type == "cpu" else torch.half)

    self.n_heads = n_heads
    self.max_seqlen = max_seqlen
    self.local_blocks = local_blocks
    self.vert_stride = vert_stride
    self.use_spda = use_spda
    self.dtype = dtype
    self.device = device
    self.block_size = block_size
    self.q_block_size = q_block_size
    self.homo_head = homo_head
    self.active_head_range = active_head_range
    self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride,
                                                   homo_head)

    sparse_layout, sparse_pattern, self.dense_attn_mask = (
        self.get_attn_pattern(dtype, device))

    if q_block_size is not None and q_block_size != block_size:
        if q_block_size > block_size:
            assert q_block_size % block_size == 0
            blocks_to_merge = q_block_size // block_size
            shape = sparse_pattern.shape
            sparse_pattern = sparse_pattern.view(shape[0], -1,
                                                 blocks_to_merge,
                                                 shape[-1])
            sparse_pattern = sparse_pattern.sum(2)
            sparse_layout = dense_to_crow_col(sparse_pattern)
        else:
            raise ValueError(
                "Does not support smaller q_block_size. It will be slower."
            )

    self.sparse_layout = sparse_layout

forward

forward(
    q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None
)

Dispatch to varlen_attn (Ampere or newer) or self.spda (CPU, Volta, Turing, or older) based on the device type and CUDA compute capability.

q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). Supports grouped attention, where q[:, i*r:(i*r + r)] corresponds to k[:, i], with r being the q/k head ratio.
cu_seqlens_k: shape=(batch_size + 1,), indicating the segment boundaries of the samples, e.g., k[cu_seqlens_k[i]:cu_seqlens_k[i+1]] is the key of sample i.
cu_seqlens_q: shape=(batch_size + 1,). Defaults to None: same as cu_seqlens_k for prefilling, or [0, 1, .., batch_size] for decoding. It only needs to be specified when q is a mix of prefilling and decoding.
sm_scale: softmax scale, defaults to 1/sqrt(head_size).

return: tensor of the same shape as q.
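
As a hedged illustration of the grouped (GQA) layout described above, the sketch below assumes attn is a LocalStridedBlockSparseAttn instance like the one constructed in the earlier usage sketch; the shapes are invented for the example.

import torch

# 8 query heads share 2 kv heads, so r = 4: q[:, 0:4] attends against k[:, 0],
# q[:, 4:8] against k[:, 1].
cu_seqlens_k = torch.tensor([0, 5, 12], dtype=torch.int32, device=attn.device)
q = torch.randn(12, 8, 128, dtype=attn.dtype, device=attn.device)
k = torch.randn(12, 2, 128, dtype=attn.dtype, device=attn.device)
v = torch.randn(12, 2, 128, dtype=attn.dtype, device=attn.device)

# Prefill: cu_seqlens_q defaults to cu_seqlens_k.
out = attn(q, k, v, cu_seqlens_k)  # out.shape == (12, 8, 128)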

Source code in vllm/attention/ops/blocksparse_attention/interface.py
def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
    """Dispatch to `varlen_attn` (Ampere or newer) or
    `self.spda`(cpu, Volta, Turing or older)based on
    the type of device used and cuda compute capability.

    q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
            Support grouped attention, with `q[:, i*r:(i*r + r)]`
            is correspondent to `k[:, i]`, where `r` is the q/k ratio.
    cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples,
                e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
    cu_seqlens_q: shape=(batch_size + 1, ).
                Default None: same as cu_seqlens_k for prefilling or
                [0, 1, .., batch_size] for decoding.
                The only case you need to specify
                is when q is a mix of prefilling
                and decoding.
    sm_scale: softmax scale, default to 1/sqrt(head_size).

    return: tensor of shape as q.
    """
    assert k.dim() == 3
    if self.use_spda:
        return self.spda(
            q,
            k,
            v,
            cu_seqlens_k,
            cu_seqlens_q=cu_seqlens_q,
            sm_scale=sm_scale,
        )
    return self.varlen_attn(q,
                            k,
                            v,
                            cu_seqlens_k,
                            cu_seqlens_q=cu_seqlens_q,
                            sm_scale=sm_scale)

get_attn_pattern

get_attn_pattern(dtype, device)
Source code in vllm/attention/ops/blocksparse_attention/interface.py
def get_attn_pattern(self, dtype, device):
    sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask(
        self.n_heads,
        self.max_seqlen,
        self.max_seqlen,
        dtype,
        device,
        block_size=self.block_size,
        local_blocks=self.local_blocks,
        vert_stride=self.vert_stride,
        homo_head=self.homo_head,
        return_dense=self.use_spda,
        dense_mask_type="bias",
    )
    if (not self.homo_head) and (self.active_head_range is not None):
        assert isinstance(self.active_head_range, tuple)
        assert (len(self.active_head_range) == 2)
        h_start, h_end = self.active_head_range
        sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout)
        if self.use_spda:
            dense_attn_mask = dense_attn_mask[h_start:h_end]
    return sparse_layout, sparse_pattern, dense_attn_mask

spda

spda(
    q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None
)

For CPU, V100, or other older GPUs. NOTE: torch SDPA supports nested tensors, but it appears to be extremely slow, so we pad instead.
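
A hedged sketch of forcing this padded SDPA path, e.g. to run on CPU; the configuration values below are illustrative, not taken from the source.

# Force the padded SDPA fallback instead of the Triton varlen kernel.
attn_cpu = LocalStridedBlockSparseAttn(
    n_heads=8,
    max_seqlen=1024,
    local_blocks=4,
    vert_stride=8,
    block_size=64,
    device="cpu",
    use_spda=True,
)
# forward() now routes to attn_cpu.spda(), which pads every sample to the
# batch max length and calls torch.nn.functional.scaled_dot_product_attention
# with the precomputed dense block-sparse bias mask.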

Source code in vllm/attention/ops/blocksparse_attention/interface.py
def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
    """For CPU, V100 or other older GPUs.
    NOTE: torch SPDA supports nested tensor,
    but seems extremely slow. Choose to pad instead.
    """
    assert (cu_seqlens_q is None or
            (cu_seqlens_q
             == cu_seqlens_k).all()), "Can only handle prompt with SPDA."
    assert q.size(0) == k.size(0), "can only handle prompt with SPDA."

    assert q.size(1) % k.size(1) == 0
    q_k_ratio = q.size(1) // k.size(1)
    sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
    cu_seqlens = cu_seqlens_k.cpu()
    maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

    if (self.dense_attn_mask.dtype != q.dtype
            or self.dense_attn_mask.device != q.device):
        _, _, self.dense_attn_mask = self.get_attn_pattern(
            q.dtype, q.device)
    attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen]

    q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1)
    k2, v2 = (self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
              for x in [k, v])
    spda_output = torch.nn.functional.scaled_dot_product_attention(
        q2, k2, v2, attn_mask=attn_mask, scale=sm_scale)
    return self.transpose_and_unpad(spda_output, cu_seqlens)

transpose_and_pad staticmethod

transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1)

:param x: (total_tokens, n_heads, head_size)
:return: (batch, n_heads, length, head_size)
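
A small shape sketch using illustrative values, assuming LocalStridedBlockSparseAttn has been imported as in the earlier usage sketch.

import torch

# Two samples of 3 and 5 tokens packed into 8 rows, 4 heads, head_size 16.
x = torch.randn(8, 4, 16)
cu_seqlens = torch.tensor([0, 3, 8])
x_padded = LocalStridedBlockSparseAttn.transpose_and_pad(x, cu_seqlens, maxlen=5)
# x_padded.shape == (2, 4, 5, 16); slots beyond each sample's length hold
# uninitialized padding, which the attention bias later masks out.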

Source code in vllm/attention/ops/blocksparse_attention/interface.py
@staticmethod
def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1):
    """
    :param x: (total_tokens, n_heads, head_size)
    :return: (batch, n_heads, length, head_size)
    """
    x_padded = x.new_empty(
        len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2))
    cu_seqlens = cu_seqlens.cpu()
    for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
        x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0,
                                                         1).unsqueeze(1))
    return x_padded.flatten(1, 2)

transpose_and_unpad staticmethod

transpose_and_unpad(x_padded, cu_seqlens)

:param x_padded: (batch, n_heads, length, head_size)
:return: (total_tokens, n_heads, head_size)
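
Continuing the hypothetical pad example above, unpadding recovers the packed layout.

x_unpadded = LocalStridedBlockSparseAttn.transpose_and_unpad(x_padded, cu_seqlens)
# x_unpadded.shape == (8, 4, 16) and matches x on all valid tokens.
assert torch.equal(x_unpadded, x)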

Source code in vllm/attention/ops/blocksparse_attention/interface.py
@staticmethod
def transpose_and_unpad(x_padded, cu_seqlens):
    """
    :param x_padded: (batch, n_heads, length, head_size)
    :return: (total_tokens, n_heads, head_size)
    """
    cu_seqlens = cu_seqlens.cpu()
    total_n_tokens = cu_seqlens[-1]
    x = x_padded.new_empty(total_n_tokens, x_padded.size(1),
                           x_padded.size(3))
    for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
        x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1))
    return x

varlen_attn

varlen_attn(
    q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None
)

q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). Supports grouped attention, where q[:, i*r:(i*r + r)] corresponds to k[:, i], with r being the q/k head ratio.
cu_seqlens_k: shape=(batch_size + 1,), indicating the segment boundaries of the samples, e.g., k[cu_seqlens_k[i]:cu_seqlens_k[i+1]] is the key of sample i.
cu_seqlens_q: shape=(batch_size + 1,). Defaults to None: same as cu_seqlens_k for prefilling, or [0, 1, .., batch_size] for decoding. It only needs to be specified when q is a mix of prefilling and decoding.
sm_scale: softmax scale, defaults to 1/sqrt(head_size).

return: tensor of the same shape as q.
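
A hedged sketch of the mixed prefill/decode case that requires an explicit cu_seqlens_q, assuming an Ampere-or-newer GPU and the attn instance from the earlier usage sketch; shapes are illustrative.

import torch

# Sample 0 prefills 5 tokens; sample 1 decodes 1 new token against a
# 7-token context.
cu_seqlens_q = torch.tensor([0, 5, 6], dtype=torch.int32, device=attn.device)
cu_seqlens_k = torch.tensor([0, 5, 12], dtype=torch.int32, device=attn.device)
q = torch.randn(6, 8, 128, dtype=attn.dtype, device=attn.device)
k = torch.randn(12, 8, 128, dtype=attn.dtype, device=attn.device)
v = torch.randn(12, 8, 128, dtype=attn.dtype, device=attn.device)
out = attn.varlen_attn(q, k, v, cu_seqlens_k, cu_seqlens_q=cu_seqlens_q)
# out.shape == (6, 8, 128)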

Source code in vllm/attention/ops/blocksparse_attention/interface.py
def varlen_attn(self,
                q,
                k,
                v,
                cu_seqlens_k,
                cu_seqlens_q=None,
                sm_scale=None):
    """
    q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
    Supports grouped attention, where `q[:, i*r:(i*r + r)]`
    corresponds to `k[:, i]`, with `r` being the q/k head ratio.
    cu_seqlens_k: shape=(batch_size + 1,),
    indicating the segment boundaries of the samples,
    e.g., `k[cu_seqlens_k[i]:cu_seqlens_k[i+1]]` is the key of sample i.
    cu_seqlens_q: shape=(batch_size + 1,).
    Defaults to None: same as cu_seqlens_k for prefilling, or
    [0, 1, .., batch_size] for decoding.
    It only needs to be specified when q is a mix of
    prefilling and decoding.
    sm_scale: softmax scale, defaults to 1/sqrt(head_size).

    return: tensor of the same shape as q.
    """
    assert (
        IS_COMPUTE_8_OR_ABOVE
    ), "Requires compute capability of 8 or above (Ampere or newer) to use \
        Triton kernel."

    sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))

    return blocksparse_flash_attn_varlen_fwd(
        q,
        k,
        v,
        cu_seqlens_k,
        cu_seqlens_q,
        sm_scale,
        self.sparse_layout,
        block_size=self.block_size,
        q_block_size=self.q_block_size,
        max_seqlen=self.max_seqlen,
    )