def _state_passing_fwd(
    states,
    dA_chunk_cumsum,
    initial_states=None,
    seq_idx=None,
    chunk_size=None,
    out_dtype=None,
    is_cont_batched=False,
):
    """Validate arguments, allocate outputs and launch ``_state_passing_fwd_kernel``.

    Args:
        states: 4-D tensor, unpacked as ``(batch, nchunks, nheads, dim)``.
        dA_chunk_cumsum: tensor that must have shape ``(batch, nheads, nchunks)``.
        initial_states: optional tensor. In regular batching it must have
            shape ``(batch, nheads, dim)``. With ``is_cont_batched=True`` its
            shape is not checked here, but ``seq_idx`` becomes mandatory.
        seq_idx: optional 2-D tensor of shape ``(batch, seqlen)``. When given,
            ``chunk_size`` must be given too.
        chunk_size: forwarded to the kernel only when ``seq_idx`` is given
            (``0`` is forwarded otherwise).
        out_dtype: dtype of ``out``; defaults to ``states.dtype``.
        is_cont_batched: forwarded to the kernel as ``IS_CONT_BATCHED``.

    Returns:
        ``(out, final_states)``. ``out`` has shape ``(batch, nchunks, nheads, dim)``
        and dtype ``out_dtype``. ``final_states`` has shape ``(batch, nheads, dim)``
        and dtype ``torch.float32``. Both are allocated uninitialized with
        ``torch.empty`` and handed to the kernel, which presumably writes them
        in full (confirm against the kernel).

    Raises:
        AssertionError: if any shape / argument check fails.
    """
    batch, nchunks, nheads, dim = states.shape
    assert dA_chunk_cumsum.shape == (batch, nheads, nchunks), (
        f"dA_chunk_cumsum must have shape {(batch, nheads, nchunks)}, "
        f"got {tuple(dA_chunk_cumsum.shape)}")

    if initial_states is not None:
        if is_cont_batched:
            # Continuous batching (per the original note: the case where
            # cu_seqlens is provided). This mode requires seq_idx.
            assert seq_idx is not None, (
                "seq_idx must be provided when initial_states is used "
                "with is_cont_batched=True")
        else:
            # Regular batching: one initial state per example in the batch.
            assert initial_states.shape == (batch, nheads, dim), (
                f"initial_states must have shape {(batch, nheads, dim)}, "
                f"got {tuple(initial_states.shape)}")

    # seqlen is only meaningful with seq_idx; the kernel receives 0 otherwise.
    seqlen = 0
    if seq_idx is not None:
        assert chunk_size is not None, (
            "chunk_size must be provided when seq_idx is given")
        seqlen = seq_idx.shape[-1]
        assert seq_idx.shape == (batch, seqlen), (
            f"seq_idx must have shape (batch={batch}, seqlen), "
            f"got {tuple(seq_idx.shape)}")

    out_dtype = states.dtype if out_dtype is None else out_dtype
    out = torch.empty((batch, nchunks, nheads, dim),
                      device=states.device,
                      dtype=out_dtype)
    final_states = torch.empty((batch, nheads, dim),
                               device=states.device,
                               dtype=torch.float32)

    def grid(META):
        # One program per (dim tile, batch, head).
        return (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)

    # Optional tensors contribute zero strides when absent.
    initstates_strides = ((initial_states.stride(0), initial_states.stride(1),
                           initial_states.stride(2))
                          if initial_states is not None else (0, 0, 0))
    seq_idx_strides = ((seq_idx.stride(0), seq_idx.stride(1))
                       if seq_idx is not None else (0, 0))

    with torch.cuda.device(states.device.index):
        _state_passing_fwd_kernel[grid](
            states,
            out,
            final_states,
            dA_chunk_cumsum,
            initial_states,
            seq_idx,
            dim,
            nchunks,
            seqlen,
            chunk_size if seq_idx is not None else 0,
            states.stride(0),
            states.stride(1),
            states.stride(2),
            states.stride(3),
            out.stride(0),
            out.stride(1),
            out.stride(2),
            out.stride(3),
            final_states.stride(0),
            final_states.stride(1),
            final_states.stride(2),
            # dA_chunk_cumsum is (batch, nheads, nchunks); its strides are
            # passed as batch, chunk (dim 2), head (dim 1), presumably to
            # match the (batch, chunk, head) order used for `states` above.
            # TODO confirm against the kernel signature.
            dA_chunk_cumsum.stride(0),
            dA_chunk_cumsum.stride(2),
            dA_chunk_cumsum.stride(1),
            *initstates_strides,
            *seq_idx_strides,
            HAS_INITSTATES=initial_states is not None,
            HAS_SEQ_IDX=seq_idx is not None,
            IS_CONT_BATCHED=is_cont_batched,
        )
    return out, final_states