vllm.v1.attention.backends.mla.prefill.flashinfer

FlashInfer backend for MLA prefill.

FlashInferPrefillBackend

Bases: MLAPrefillBackend

FlashInfer backend for MLA prefill.

Source code in vllm/v1/attention/backends/mla/prefill/flashinfer.py
class FlashInferPrefillBackend(MLAPrefillBackend):
    """FlashInfer backend for MLA prefill."""

    # Assumption: this flag restricts the backend to DeepSeek-R1-style MLA
    # head dimensions rather than arbitrary MLA shapes.
    requires_r1_mla_dimensions = True

    @staticmethod
    def get_name() -> str:
        return "FLASHINFER"

    @classmethod
    def supports_compute_capability(cls, device_capability: "DeviceCapability") -> bool:
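        # Compute capability 10.x corresponds to NVIDIA Blackwell GPUs.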
        return device_capability.major == 10

    @classmethod
    def is_available(cls) -> bool:
        try:
            from flashinfer import (
                BatchPrefillWithRaggedKVCacheWrapper,  # noqa: F401
            )

            return True
        except ImportError:
            return False

    def __init__(
        self,
        num_heads: int,
        scale: float,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        vllm_config: "VllmConfig",
        device: torch.device,
        layer_names: list[str] | None = None,
    ) -> None:
        super().__init__(
            num_heads=num_heads,
            scale=scale,
            kv_lora_rank=kv_lora_rank,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            v_head_dim=v_head_dim,
            vllm_config=vllm_config,
            device=device,
            layer_names=layer_names,
        )

        self._prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None
        self._prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = []
        if layer_names is None:
            raise ValueError(
                "FlashInferPrefillBackend requires layer_names to "
                "initialize global hyperparameters."
            )

        from vllm.model_executor.layers.attention.mla_attention import (
            MLACommonImpl,
        )

        self._global_hyperparameters = infer_global_hyperparameters(
            get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)  # type: ignore[type-abstract]
        )

    def _ensure_chunks(
        self,
        num_chunks: int,
        workspace_buffer: torch.Tensor,
    ) -> None:
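        # Lazily grow the pool of per-chunk prefill wrappers to at least num_chunks.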
        if len(self._prefill_chunks) < num_chunks:
            for _ in range(len(self._prefill_chunks), num_chunks):
                self._prefill_chunks.append(
                    BatchPrefillWithRaggedKVCacheWrapper(
                        workspace_buffer, "NHD", backend="cutlass"
                    )
                )

    def prepare_metadata(
        self,
        prefill_metadata: "MLACommonPrefillMetadata",
    ) -> None:
        qo_indptr = prefill_metadata.query_start_loc
        has_context = prefill_metadata.chunked_context is not None
        (workspace_buffer,) = current_workspace_manager().get_simultaneous(
            ((envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,), torch.uint8),
        )

        # Create the FlashInfer wrappers lazily on first use; they are reused
        # across subsequent steps.
        if self._prefill_main is None:
            self._prefill_main = BatchPrefillWithRaggedKVCacheWrapper(
                workspace_buffer, "NHD", backend="cutlass"
            )
            self._ensure_chunks(_DEFAULT_NUM_CHUNKS, workspace_buffer)

        if has_context:
            chunked_context = prefill_metadata.chunked_context
            assert chunked_context is not None
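            # cu_seq_lens holds one cumulative-sequence-length vector per context chunk.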
            num_chunks = chunked_context.cu_seq_lens.shape[0]
            self._ensure_chunks(num_chunks, workspace_buffer)

        # MLA prefill runs as ordinary multi-head attention over materialized
        # per-head K/V, so the KV head count equals the query head count.
        num_qo_heads = self.num_heads
        num_kv_heads = num_qo_heads

        head_dim_qk = self.qk_nope_head_dim + self.qk_rope_head_dim
        head_dim_vo = self.v_head_dim
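        # The new-token pass attends over the new tokens themselves, so the
        # K/V spans equal the query spans.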
        kv_indptr = qo_indptr.clone()

        assert self._prefill_main is not None
        self._prefill_main.plan(
            qo_indptr=qo_indptr,
            kv_indptr=kv_indptr,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim_qk=head_dim_qk,
            head_dim_vo=head_dim_vo,
            causal=True,
            sm_scale=self._global_hyperparameters.sm_scale,
            window_left=self._global_hyperparameters.window_left,
            logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
            q_data_type=prefill_metadata.q_data_type,
            o_data_type=prefill_metadata.output_dtype,
        )

        if has_context:
            # chunked_context and num_chunks were already computed above when
            # context exists; no need to re-fetch them here.
            for i in range(num_chunks):
                kv_indptr_chunk = chunked_context.cu_seq_lens[i]

                # Context tokens all precede the new queries, so this pass is
                # non-causal; chunk outputs are merged later using the returned LSE.
                self._prefill_chunks[i].plan(
                    qo_indptr=qo_indptr,
                    kv_indptr=kv_indptr_chunk,
                    num_qo_heads=num_qo_heads,
                    num_kv_heads=num_kv_heads,
                    head_dim_qk=head_dim_qk,
                    head_dim_vo=head_dim_vo,
                    causal=False,
                    sm_scale=self._global_hyperparameters.sm_scale,
                    window_left=self._global_hyperparameters.window_left,
                    logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
                    q_data_type=prefill_metadata.q_data_type,
                    o_data_type=prefill_metadata.output_dtype,
                )

    def run_prefill_new_tokens(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        return_softmax_lse: bool,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert self._prefill_main is not None

        ret = self._prefill_main.run(
            q=q,
            k=k,
            v=v,
            return_lse=return_softmax_lse,
        )

        if isinstance(ret, tuple):
            # Convert from (q_len, num_heads) to (num_heads, q_len)
            return ret[0], ret[1].transpose(0, 1).contiguous()
        return ret

    def run_prefill_context_chunk(
        self,
        chunk_idx: int,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        attn_out, lse = self._prefill_chunks[chunk_idx].run(
            q=q,
            k=k,
            v=v,
            return_lse=True,
        )

        # Convert from (q_len, num_heads) to (num_heads, q_len)
        return attn_out, lse.transpose(0, 1).contiguous()
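
A minimal sketch of how a caller might probe whether this backend is usable before selecting it. Only is_available and supports_compute_capability come from the class above; the current_platform capability probe is an assumption about the surrounding vLLM API and may differ across versions.

# Sketch only: assumes vllm.platforms.current_platform exposes
# get_device_capability(), returning a DeviceCapability with .major/.minor.
from vllm.platforms import current_platform

from vllm.v1.attention.backends.mla.prefill.flashinfer import (
    FlashInferPrefillBackend,
)

capability = current_platform.get_device_capability()
usable = (
    # True only if flashinfer's ragged prefill wrapper is importable
    FlashInferPrefillBackend.is_available()
    # and the GPU is Blackwell (compute capability 10.x)
    and capability is not None
    and FlashInferPrefillBackend.supports_compute_capability(capability)
)
print(f"FlashInfer MLA prefill usable: {usable}")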