vllm.model_executor.layers.fla.ops.solve_tril
solve_tril
solve_tril(
A: Tensor,
cu_seqlens: Tensor | None = None,
chunk_indices: Tensor | None = None,
output_dtype: dtype = float,
) -> Tensor
Compute the inverse of the matrix I + A, where A must be strictly lower triangular, i.e., A.triu() == 0.
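Because A is strictly lower triangular, I + A is unit lower triangular, and its inverse can be found by forward substitution: row i of (I + A) X = I gives X[i][j] = δ_ij - Σ_{k<i} A[i][k] X[k][j]. The sketch below applies this to a single dense square block in pure Python. It is only a reference for the math, assuming one BT × BT block at a time; it is not the Triton kernel vLLM uses, and `inv_unit_lower` is a hypothetical name.

```python
from typing import List

Matrix = List[List[float]]


def inv_unit_lower(a: Matrix) -> Matrix:
    """Hypothetical reference: return (I + A)^-1 for a strictly lower-triangular A.

    Uses plain forward substitution on one square block; not the Triton kernel.
    """
    n = len(a)
    # Enforce the documented precondition A.triu() == 0 (diagonal included).
    for i in range(n):
        for j in range(i, n):
            if a[i][j] != 0:
                raise ValueError("A must be strictly lower triangular (A.triu() == 0)")
    # X = (I + A)^-1 is unit lower triangular. Row i of (I + A) X = I gives
    # X[i][j] = delta_ij - sum_{k<i} A[i][k] * X[k][j], using rows already solved.
    x: Matrix = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            acc = 1.0 if i == j else 0.0
            for k in range(i):
                acc -= a[i][k] * x[k][j]
            x[i][j] = acc
    return x


if __name__ == "__main__":
    a = [[0, 0, 0], [2, 0, 0], [3, 4, 0]]
    print(inv_unit_lower(a))  # → [[1.0, 0.0, 0.0], [-2.0, 1.0, 0.0], [5.0, -4.0, 1.0]]
```

Each row only depends on rows above it, which is why a blocked kernel can restrict this work to the BT × BT diagonal blocks.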
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A | Tensor | Input of shape [B, T, H, BT], where BT (the chunk size) must be 16, 32, or 64. | required |
| cu_seqlens | Tensor | Cumulative sequence lengths for variable-length (packed) input, e.g. [0, 5, 12] for two sequences of lengths 5 and 7. | None |
| chunk_indices | Tensor | Pre-computed chunk indices. | None |
| output_dtype | dtype | The dtype of the output tensor. | float |
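With variable-length input, the sequences described by `cu_seqlens` are packed along the T axis, and each one is cut into chunks of BT tokens. The sketch below shows one plausible meaning of "pre-computed chunk indices": a list of (sequence index, chunk index within that sequence) pairs, one per chunk. This layout is an assumption for illustration, the helper `chunk_indices_from_cu_seqlens` is hypothetical, and the tensor vLLM actually expects for `chunk_indices` may be encoded differently.

```python
from typing import List, Tuple


def chunk_indices_from_cu_seqlens(cu_seqlens: List[int], chunk_size: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: list (sequence index, chunk index within sequence)
    for every chunk of `chunk_size` tokens in a packed variable-length batch."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    out: List[Tuple[int, int]] = []
    # Consecutive entries of cu_seqlens delimit one sequence: [start, end).
    for seq, (start, end) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
        n_chunks = -(-(end - start) // chunk_size)  # ceil division; last chunk may be partial
        out.extend((seq, c) for c in range(n_chunks))
    return out


if __name__ == "__main__":
    # Two sequences of lengths 5 and 7, chunk size 4: two chunks each.
    print(chunk_indices_from_cu_seqlens([0, 5, 12], 4))  # → [(0, 0), (0, 1), (1, 0), (1, 1)]
```

Precomputing this mapping once lets repeated calls over the same packed batch skip recomputing chunk boundaries.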
Returns:
| Type | Description |
|---|---|
| Tensor | (I + A)^-1, with the same shape as A. |