Skip to content

vllm.model_executor.layers.mamba.ops.gdn_chunk_cutedsl

Modules:

Name Description
kernel_h
kernel_kkt_inv_uw
kernel_o

chunk_gated_delta_rule_cutedsl

chunk_gated_delta_rule_cutedsl(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    g: Tensor,
    beta: Tensor,
    initial_state: Tensor,
    cu_seqlens: Tensor,
    chunk_indices: Tensor,
    chunk_offsets: Tensor,
    core_attn_out: Tensor | None = None,
) -> tuple[Tensor, Tensor]

Run the GDN chunk CuteDSL prefill kernels.

Parameters:

Name Type Description Default
q Tensor

Query tensor with shape [1, T, H, K].

required
k Tensor

Key tensor with shape [1, T, H, K].

required
v Tensor

Value tensor with shape [1, T, Hv, V].

required
g Tensor

Log-space decay tensor with shape [1, T, Hv].

required
beta Tensor

Delta-rule beta tensor with shape [1, T, Hv].

required
initial_state Tensor

Recurrent state with shape [N, Hv, V, K].

required
cu_seqlens Tensor

Cumulative sequence lengths with shape [N + 1].

required
chunk_indices Tensor

Chunk index metadata with shape [NT, 2].

required
chunk_offsets Tensor

Cumulative chunk offsets with shape [N + 1].

required
core_attn_out Tensor | None

Optional output buffer with shape [T, Hv, V].

None

Returns:

Type Description
tuple[Tensor, Tensor]

A tuple (output, final_state) where output has shape [1, T, Hv, V] and final_state has shape [N, Hv, V, K]. When core_attn_out is provided, output is an unsqueezed view of that buffer.

Source code in vllm/model_executor/layers/mamba/ops/gdn_chunk_cutedsl/__init__.py
def chunk_gated_delta_rule_cutedsl(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    initial_state: torch.Tensor,
    cu_seqlens: torch.Tensor,
    chunk_indices: torch.Tensor,
    chunk_offsets: torch.Tensor,
    core_attn_out: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the GDN chunk CuteDSL prefill kernels.

    Three kernels are launched back to back:

    1. ``kkt_inv_uw_cutedsl`` is handed freshly allocated ``u``, ``w`` and
       float32 ``g_cu`` buffers together with ``k``, ``v``, ``g`` and ``beta``.
    2. ``h_cutedsl`` reads ``u``, ``w`` and ``g_cu``. It is handed fresh
       ``h``, ``v_new`` and ``final_state`` buffers plus ``initial_state``.
    3. ``o_cutedsl`` reads ``q``, ``k``, ``v_new`` (viewed as
       ``[NT, 64, Hv, V]``), ``h`` and ``g_cu``. It writes into the output
       buffer using a scale of ``K ** -0.5``.

    The intermediate ``u``, ``w`` and ``v_new`` buffers are sized to
    ``NT * 64`` rows. Here ``NT`` is ``chunk_indices.shape[0]`` and 64 is
    the fixed chunk length.

    Args:
        q: Query tensor with shape ``[1, T, H, K]``.
        k: Key tensor with shape ``[1, T, H, K]``.
        v: Value tensor with shape ``[1, T, Hv, V]``.
        g: Log-space decay tensor with shape ``[1, T, Hv]``.
        beta: Delta-rule beta tensor with shape ``[1, T, Hv]``.
        initial_state: Recurrent state with shape ``[N, Hv, V, K]``.
        cu_seqlens: Cumulative sequence lengths with shape ``[N + 1]``.
        chunk_indices: Chunk index metadata with shape ``[NT, 2]``.
        chunk_offsets: Cumulative chunk offsets with shape ``[N + 1]``.
        core_attn_out: Optional output buffer with shape ``[T, Hv, V]``.
            When omitted, a buffer like the squeezed ``v`` is allocated.

    Returns:
        A tuple ``(output, final_state)`` where ``output`` has shape
        ``[1, T, Hv, V]`` and ``final_state`` has shape ``[N, Hv, V, K]``.
        ``final_state`` is a newly allocated tensor. When ``core_attn_out``
        is provided, ``output`` is an unsqueezed view of that buffer.
    """
    # Strip the leading size-1 batch axis; the kernels take the squeezed views.
    q_tok = q.squeeze(0)
    k_tok = k.squeeze(0)
    v_tok = v.squeeze(0)
    g_tok = g.squeeze(0)
    beta_tok = beta.squeeze(0)

    _, _, k_dim = k_tok.shape
    _, hv_heads, v_dim = v_tok.shape

    chunk_len = 64
    n_chunks_max = chunk_indices.shape[0]
    padded_len = n_chunks_max * chunk_len
    # Keep the total chunk count as a 1-element tensor slice (no `.item()`),
    # so the kernels can read it without a host round-trip.
    total_chunks_ptr = chunk_offsets[-1:]

    # Stage 1: buffers for the kkt/inverse/u-w kernel.
    g_cu = torch.empty_like(g_tok, dtype=torch.float32)
    u = q_tok.new_empty(padded_len, hv_heads, v_dim)
    w = q_tok.new_empty(padded_len, hv_heads, k_dim)

    sm_count = torch.cuda.get_device_properties(q.device).multi_processor_count
    kkt_inv_uw_cutedsl(
        k_tok, v_tok, u, w, g_tok, beta_tok, g_cu,
        cu_seqlens, chunk_indices, total_chunks_ptr,
        num_sms=sm_count,
    )

    # Stage 2: per-chunk states, v_new and the final recurrent state.
    h = k_tok.new_empty(n_chunks_max, hv_heads, v_dim, k_dim)
    v_new = q_tok.new_empty(padded_len, hv_heads, v_dim)
    final_state = torch.empty_like(initial_state)
    h_cutedsl(
        k_tok, u, w, v_new, g_cu, h,
        initial_state, final_state,
        cu_seqlens, chunk_offsets,
    )

    # Stage 3: output projection, written into the caller's buffer if given.
    if core_attn_out is None:
        out = torch.empty_like(v_tok)
    else:
        out = core_attn_out
    v_new_chunked = v_new.view(n_chunks_max, chunk_len, hv_heads, v_dim)
    o_cutedsl(
        q_tok, k_tok, v_new_chunked, h, g_cu, out,
        cu_seqlens, chunk_indices, total_chunks_ptr,
        k_dim**-0.5,
        num_sms=sm_count,
    )
    return out.unsqueeze(0), final_state