@triton.jit
def kernel_unified_attention_3d(
        segm_output_ptr,
        # [num_tokens, num_query_heads, num_segments, head_size]
        # NOTE: the store at the end of the kernel uses HEAD_SIZE_PADDED (not
        # HEAD_SIZE) as the innermost extent when computing offsets.
        segm_max_ptr,  # [num_tokens, num_query_heads, num_segments]
        segm_expsum_ptr,  # [num_tokens, num_query_heads, num_segments]
        query_ptr,  # [num_tokens, num_query_heads, head_size]
        # Indexed in the loop as (block, token-in-block, kv_head, head_dim)
        # via stride_k_cache_0..3.
        # NOTE(review): a previous comment gave a 5-D layout
        # [num_blks, num_kv_heads, head_size // x, blk_size, x], which does not
        # match the four strides used here - confirm against the caller.
        key_cache_ptr,
        # Indexed in the loop as (block, token-in-block, kv_head, head_dim)
        # via stride_v_cache_0..3.
        # NOTE(review): a previous comment gave
        # [num_blks, num_kv_heads, head_size, blk_size] - confirm.
        value_cache_ptr,
        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
        seq_lens_ptr,  # [num_seqs]
        alibi_slopes_ptr,  # [num_query_heads]
        scale,  # float32
        k_scale,  # read with tl.load below, i.e. used as a pointer to a scale
        v_scale,  # read with tl.load below, i.e. used as a pointer to a scale
        softcap,  # float32
        num_query_heads: tl.constexpr,  # int
        num_queries_per_kv: tl.constexpr,  # int
        block_table_stride: tl.int64,  # int
        query_stride_0: tl.int64,  # int
        query_stride_1: tl.int64,  # int, should be equal to head_size
        BLOCK_SIZE: tl.constexpr,  # int
        HEAD_SIZE: tl.constexpr,  # int
        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
        USE_ALIBI_SLOPES: tl.constexpr,  # bool
        USE_SOFTCAP: tl.constexpr,  # bool
        SLIDING_WINDOW: tl.constexpr,  # int
        stride_k_cache_0: tl.int64,  # int
        stride_k_cache_1: tl.int64,  # int
        stride_k_cache_2: tl.int64,  # int
        stride_k_cache_3: tl.constexpr,  # int
        stride_v_cache_0: tl.int64,  # int
        stride_v_cache_1: tl.int64,  # int
        stride_v_cache_2: tl.int64,  # int
        stride_v_cache_3: tl.constexpr,  # int
        query_start_len_ptr,  # [num_seqs+1]
        BLOCK_Q: tl.constexpr,  # int
        num_seqs: tl.int32,
        BLOCK_M: tl.constexpr,  # int
        NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
):
    """Compute per-segment partial attention over a paged KV cache.

    Each program instance handles one (query block, KV head, segment)
    triple, taken from program ids 0, 1 and 2 respectively. It walks the KV
    cache blocks belonging to its segment of the sequence, maintains an
    online-softmax running maximum ``M``, running exponential sum ``L`` and
    an accumulator ``acc``, and finally writes all three to
    ``segm_output_ptr``, ``segm_max_ptr`` and ``segm_expsum_ptr``.

    The accumulator is stored as-is: this kernel never divides ``acc`` by
    ``L``. Presumably a separate reduction step merges the segments and
    normalises the result - confirm against the caller.

    Rows of the BLOCK_M tile pack (query position, query head) pairs: row
    ``r`` maps to query position ``r // num_queries_per_kv`` within the
    block and to query head ``kv_head_idx * num_queries_per_kv +
    r % num_queries_per_kv``.

    Helpers ``find_seq_idx``, ``cdiv_fn`` and ``apply_softcap`` are defined
    outside this block; comments about them below are assumptions.
    """
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)
    segm_idx = tl.program_id(2)

    # Presumably maps the global query-block index to the sequence owning it
    # (helper not visible here - confirm).
    seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs,
                           BLOCK_Q, True)

    # Global index of this sequence's first query block. The "+ seq_idx" term
    # looks like it reserves one extra block per preceding sequence -
    # NOTE(review): confirm against the launch grid computation.
    q_block_start_idx = tl.load(query_start_len_ptr +
                                seq_idx) // BLOCK_Q + seq_idx

    q_block_local_idx = q_block_global_idx - q_block_start_idx

    # Token range [start, stop) of this sequence's queries in the flat
    # query tensor.
    cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
    cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

    cur_batch_query_len = cur_batch_in_all_stop_index \
        - cur_batch_in_all_start_index

    # This query block starts past the end of the sequence's queries:
    # nothing to do.
    if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
        return

    # sequence len for this particular sequence
    seq_len = tl.load(seq_lens_ptr + seq_idx)

    # number of segments for this particular sequence
    num_segments = NUM_SEGMENTS_PER_SEQ
    # KV blocks assigned to each segment (cdiv_fn is presumably a ceiling
    # division - confirm).
    blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)

    # This segment starts at or beyond the end of the sequence: exit without
    # writing anything for it.
    if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len:
        return

    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)

    # Query position (relative to the sequence's first query token) for
    # each tile row.
    query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

    # Token index and query-head index of each tile row.
    query_offset_0 = cur_batch_in_all_start_index + query_pos
    query_offset_1 = kv_head_idx * num_queries_per_kv + \
        offs_m % num_queries_per_kv

    # The head dimension is addressed with unit stride (offs_d is added
    # without a stride factor).
    query_offset = (query_offset_0[:, None] * query_stride_0 +
                    query_offset_1[:, None] * query_stride_1 + offs_d[None, :])

    # dim_mask: real (non-padding) head-dim lanes.
    # query_mask_0: rows whose query position lies inside this sequence.
    # query_mask_1: rows whose query head index is valid.
    dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
    query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
    query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)

    # Q : (BLOCK_M, HEAD_SIZE_PADDED)
    # Masked-out elements are loaded as 0.0.
    Q = tl.load(
        query_ptr + query_offset,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
        other=0.0,
    )

    block_table_offset = seq_idx * block_table_stride

    # Online-softmax state. L starts at 1.0, but in the first loop iteration
    # alpha = exp(-inf - m_j) = 0, so that initial value is discarded.
    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)

    # context length for this particular sequences
    # (sequence length minus the number of query tokens in this call)
    context_len = seq_len - cur_batch_query_len

    # alibi slope for this head
    if USE_ALIBI_SLOPES:
        alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1,
                              mask=query_mask_1,
                              other=0.0)

    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)

    # iterate through tiles within current segment
    # (the upper bound is clamped to num_blocks so the last segment may be
    # shorter)
    for j in range(
            segm_idx * blocks_per_segment,
            min((segm_idx + 1) * blocks_per_segment, num_blocks),
    ):
        # Translate the logical block index j into a physical cache block.
        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)

        offs_n = tl.arange(0, BLOCK_SIZE)

        # V offsets are laid out (BLOCK_SIZE, HEAD_SIZE_PADDED); K offsets
        # are laid out transposed, (HEAD_SIZE_PADDED, BLOCK_SIZE), so that
        # tl.dot(Q, K) needs no explicit transpose.
        v_offset = (physical_block_idx * stride_v_cache_0 +
                    kv_head_idx * stride_v_cache_2 +
                    offs_d[None, :] * stride_v_cache_3 +
                    offs_n[:, None] * stride_v_cache_1)

        k_offset = (physical_block_idx * stride_k_cache_0 +
                    kv_head_idx * stride_k_cache_2 +
                    offs_d[:, None] * stride_k_cache_3 +
                    offs_n[None, :] * stride_k_cache_1)

        # K : (HEAD_SIZE, BLOCK_SIZE)
        # Only padded head-dim lanes are masked on load; token slots past
        # seq_len are loaded and then excluded from S via seq_mask below.
        K_load = tl.load(key_cache_ptr + k_offset,
                         mask=dim_mask[:, None],
                         other=0.0)

        # An fp8 cache with a non-fp8 query is dequantised: upcast to
        # float32, multiply by the scale loaded from k_scale, cast to Q's
        # dtype. If Q is fp8 as well, the raw values are used unchanged.
        if K_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                K = K_load
            else:
                K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
        else:
            K = K_load

        # V : (BLOCK_SIZE, HEAD_SIZE)
        V_load = tl.load(value_cache_ptr + v_offset,
                         mask=dim_mask[None, :],
                         other=0.0)

        # Same fp8 handling as for K, using v_scale.
        if V_load.dtype.is_fp8():
            if Q.dtype.is_fp8():
                V = V_load
            else:
                V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
        else:
            V = V_load

        # Position of each key of this tile within the sequence.
        seq_offset = j * BLOCK_SIZE + offs_n

        # Causal mask: a query at position context_len + query_pos may
        # attend to keys at positions up to and including its own.
        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

        # S : (BLOCK_M, BLOCK_SIZE)
        S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)

        S += scale * tl.dot(Q, K)

        if USE_SOFTCAP:
            S = apply_softcap(S, softcap)

        # Invalid rows and non-causal keys get -inf so they contribute
        # exp(-inf) = 0 to the softmax.
        S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
                     S, float("-inf"))

        # Keep only keys whose distance behind the query is below
        # SLIDING_WINDOW.
        if SLIDING_WINDOW > 0:
            S = tl.where((context_len + query_pos[:, None] - seq_offset)
                         < SLIDING_WINDOW, S, float("-inf"))

        # ALiBi bias: per-head slope times (key position - context_len).
        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

        # compute running maximum
        # m_j : (BLOCK_M,)
        m_j = tl.maximum(M, tl.max(S, axis=1))
        # For sliding window there's a chance the max is -inf due to masking of
        # the entire row. In this case we need to set m_j 0 to avoid NaN
        m_j = tl.where(m_j > float("-inf"), m_j, 0.0)

        # P : (BLOCK_M, BLOCK_SIZE,)
        P = tl.exp(S - m_j[:, None])

        # l_j : (BLOCK_M,)
        l_j = tl.sum(P, axis=1)

        # alpha : (BLOCK_M, )
        # Rescale factor that moves previous partial results from the old
        # maximum M to the new maximum m_j.
        alpha = tl.exp(M - m_j)

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        acc = acc * alpha[:, None]

        # update constants
        L = L * alpha + l_j
        M = m_j

        # acc : (BLOCK_M, HEAD_SIZE_PADDED)
        # P is cast to V's dtype before the dot product.
        acc += tl.dot(P.to(V.dtype), V)

    # Hand-computed contiguous offsets into the segment output, treated as
    # [token, query_head, segment, HEAD_SIZE_PADDED]; the token index is
    # widened to int64 before multiplication.
    segm_output_offset = (
        query_offset_0[:, None].to(tl.int64) *
        (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) +
        segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :])
    # Store the un-normalised accumulator for valid rows and real head-dim
    # lanes only.
    tl.store(
        segm_output_ptr + segm_output_offset,
        acc,
        mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
    )
    # Offsets into the [token, query_head, segment] max / expsum buffers.
    segm_offset = (query_offset_0.to(tl.int64) *
                   (num_query_heads * NUM_SEGMENTS_PER_SEQ) +
                   query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx)
    # Store this segment's running max and exponential sum for valid rows.
    tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1)
    tl.store(segm_expsum_ptr + segm_offset,
             L,
             mask=query_mask_0 & query_mask_1)