vllm.v1.attention.backends.mla.prefill

Modules:

| Name | Description |
| --- | --- |
| base | Abstract base class for MLA prefill backends. |
| flash_attn | FlashAttention backend for MLA prefill. |
| flashinfer | FlashInfer backend for MLA prefill. |
| registry | Registry for MLA prefill backends. |
| selector | Selector for MLA prefill backends. |
| trtllm_ragged | TRT-LLM Ragged backend for MLA prefill. |

MLAPrefillBackend

Bases: ABC

Abstract base class for MLA prefill backends.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
class MLAPrefillBackend(ABC):
    """Abstract base class for MLA prefill backends."""

    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
    ]
    requires_r1_mla_dimensions: ClassVar[bool] = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @classmethod
    def supports_compute_capability(cls, device_capability: "DeviceCapability") -> bool:
        return True

    @classmethod
    def supports_dtype(cls, dtype: torch.dtype) -> bool:
        return dtype in cls.supported_dtypes

    @classmethod
    def is_available(cls) -> bool:
        return True

    @classmethod
    def validate_configuration(
        cls,
        device_capability: "DeviceCapability",
        selector_config: "MLAPrefillSelectorConfig",
    ) -> list[str]:
        invalid_reasons: list[str] = []

        if not cls.supports_compute_capability(device_capability):
            invalid_reasons.append(
                f"compute capability {device_capability.major}."
                f"{device_capability.minor} not supported"
            )

        if not cls.supports_dtype(selector_config.dtype):
            invalid_reasons.append(f"dtype {selector_config.dtype} not supported")

        if not cls.is_available():
            invalid_reasons.append("required dependencies not available")

        if cls.requires_r1_mla_dimensions and not selector_config.is_r1_compatible:
            invalid_reasons.append(
                "model does not have DeepSeek R1 MLA dimensions "
                "(qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128)"
            )

        return invalid_reasons

    def __init__(
        self,
        num_heads: int,
        scale: float,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        vllm_config: "VllmConfig",
        device: torch.device,
        layer_names: list[str] | None = None,
    ) -> None:
        self.num_heads = num_heads
        self.scale = scale
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.vllm_config = vllm_config
        self.device = device
        self.layer_names = layer_names

    def prepare_metadata(  # noqa: B027
        self,
        prefill_metadata: "MLACommonPrefillMetadata",
    ) -> None:
        """Prepare backend-specific metadata before the forward pass.

        Called by the metadata builder after constructing the prefill metadata.
        """
        self._prefill_metadata = prefill_metadata

    @abstractmethod
    def run_prefill_new_tokens(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        return_softmax_lse: bool,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    @abstractmethod
    def run_prefill_context_chunk(
        self,
        chunk_idx: int,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError
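
To make the abstract contract concrete, here is a minimal sketch of a hypothetical subclass. The base-class import path follows the source location above; the backend name and the SDPA-based attention call are illustrative assumptions, not an existing vLLM backend.

import torch

from vllm.v1.attention.backends.mla.prefill.base import MLAPrefillBackend


class NaiveSDPAPrefillBackend(MLAPrefillBackend):
    """Hypothetical backend delegating to torch SDPA, for illustration only."""

    @staticmethod
    def get_name() -> str:
        return "NAIVE_SDPA"  # illustrative name, not a registered backend

    def run_prefill_new_tokens(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        return_softmax_lse: bool,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if return_softmax_lse:
            # A real backend would return (output, softmax_lse); omitted here.
            raise NotImplementedError
        # Assumes (seq_len, num_heads, head_dim) layout; SDPA expects
        # (num_heads, seq_len, head_dim), hence the transposes.
        out = torch.nn.functional.scaled_dot_product_attention(
            q.transpose(0, 1),
            k.transpose(0, 1),
            v.transpose(0, 1),
            scale=self.scale,
            is_causal=True,
        )
        return out.transpose(0, 1)

    def run_prefill_context_chunk(
        self,
        chunk_idx: int,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Chunked-context attention must return the softmax LSE so partial
        # results can be merged; left out of this sketch.
        raise NotImplementedError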

prepare_metadata

prepare_metadata(
    prefill_metadata: MLACommonPrefillMetadata,
) -> None

Prepare backend-specific metadata before the forward pass.

Called by the metadata builder after constructing the prefill metadata.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
def prepare_metadata(  # noqa: B027
    self,
    prefill_metadata: "MLACommonPrefillMetadata",
) -> None:
    """Prepare backend-specific metadata before the forward pass.

    Called by the metadata builder after constructing the prefill metadata.
    """
    self._prefill_metadata = prefill_metadata

MLAPrefillBackendEnum

Bases: Enum

Enumeration of all supported MLA prefill backends.

Source code in vllm/v1/attention/backends/mla/prefill/registry.py
class MLAPrefillBackendEnum(Enum, metaclass=_MLAPrefillBackendEnumMeta):
    """Enumeration of all supported MLA prefill backends."""

    FLASH_ATTN = (
        "vllm.v1.attention.backends.mla.prefill.flash_attn.FlashAttnPrefillBackend"
    )
    FLASHINFER = (
        "vllm.v1.attention.backends.mla.prefill.flashinfer.FlashInferPrefillBackend"
    )
    TRTLLM_RAGGED = (
        "vllm.v1.attention.backends.mla.prefill.trtllm_ragged."
        "TrtllmRaggedPrefillBackend"
    )

    def get_path(self) -> str:
        """Get the fully qualified class path for this backend."""
        return self.value

    def get_class(self) -> "type[MLAPrefillBackend]":
        """Lazy load and return the backend class."""
        return resolve_obj_by_qualname(self.get_path())
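
As a usage sketch: the enum resolves backend classes lazily, so the backend module (and its optional dependencies) is only imported when get_class() is called. The import path follows the source location above.

from vllm.v1.attention.backends.mla.prefill.registry import MLAPrefillBackendEnum

# Importing the registry does not import any backend module.
print(MLAPrefillBackendEnum.FLASH_ATTN.get_path())
# vllm.v1.attention.backends.mla.prefill.flash_attn.FlashAttnPrefillBackend

# get_class() triggers the actual import and returns the class object.
backend_cls = MLAPrefillBackendEnum.FLASH_ATTN.get_class()
print(backend_cls.get_name())  # whatever name the backend reports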

get_class

get_class() -> type[MLAPrefillBackend]

Lazy load and return the backend class.

Source code in vllm/v1/attention/backends/mla/prefill/registry.py
def get_class(self) -> "type[MLAPrefillBackend]":
    """Lazy load and return the backend class."""
    return resolve_obj_by_qualname(self.get_path())

get_path

get_path() -> str

Get the fully qualified class path for this backend.

Source code in vllm/v1/attention/backends/mla/prefill/registry.py
def get_path(self) -> str:
    """Get the fully qualified class path for this backend."""
    return self.value

get_mla_prefill_backend

get_mla_prefill_backend(
    vllm_config: VllmConfig,
) -> type[MLAPrefillBackend]

Select the MLA prefill backend based on configuration and device.

This function first checks for explicit user preferences via mla_prefill_backend in AttentionConfig, then falls back to automatic priority-based selection.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| vllm_config | VllmConfig | The vLLM configuration. | required |

Returns:

| Type | Description |
| --- | --- |
| type[MLAPrefillBackend] | The selected prefill backend class. |

Source code in vllm/v1/attention/backends/mla/prefill/selector.py
def get_mla_prefill_backend(
    vllm_config: "VllmConfig",
) -> "type[MLAPrefillBackend]":
    """Select the MLA prefill backend based on configuration and device.

    This function first checks for explicit user preferences via
    mla_prefill_backend in AttentionConfig, then falls back to automatic
    priority-based selection.

    Args:
        vllm_config: The vLLM configuration.

    Returns:
        The selected prefill backend class.
    """
    from vllm.platforms import current_platform

    device_capability = current_platform.get_device_capability()
    if device_capability is None:
        logger.info_once(
            "Device capability not available, using FlashAttention MLA prefill backend."
        )
        return MLAPrefillBackendEnum.FLASH_ATTN.get_class()

    attention_config = vllm_config.attention_config

    selector_config = MLAPrefillSelectorConfig(
        dtype=vllm_config.model_config.dtype,
        is_r1_compatible=is_deepseek_r1_mla_compatible(vllm_config),
    )

    if attention_config.mla_prefill_backend is not None:
        selected_backend = attention_config.mla_prefill_backend
        backend_cls: type[MLAPrefillBackend] | None = None
        try:
            backend_cls = selected_backend.get_class()
            invalid_reasons = backend_cls.validate_configuration(
                device_capability, selector_config
            )
        except ImportError:
            invalid_reasons = ["ImportError"]
        if invalid_reasons:
            raise ValueError(
                f"Selected MLA prefill backend {selected_backend.name} "
                f"is not valid for this configuration. "
                f"Reason: {invalid_reasons}"
            )
        assert backend_cls is not None
        logger.info("Using %s MLA prefill backend.", selected_backend.name)
        return backend_cls

    return _auto_select_mla_prefill_backend(
        device_capability,
        selector_config,
    )
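
The private _auto_select_mla_prefill_backend is not reproduced on this page. The sketch below shows the validate-then-pick pattern it plausibly follows, built only from the public pieces documented above; the function name and priority order here are assumptions, and selector_config is an MLAPrefillSelectorConfig as constructed in the source.

from vllm.v1.attention.backends.mla.prefill.registry import MLAPrefillBackendEnum


def pick_first_valid_backend(device_capability, selector_config):
    # Assumed priority order; the real selector's ordering may differ.
    for candidate in (
        MLAPrefillBackendEnum.TRTLLM_RAGGED,
        MLAPrefillBackendEnum.FLASHINFER,
        MLAPrefillBackendEnum.FLASH_ATTN,
    ):
        try:
            backend_cls = candidate.get_class()
        except ImportError:
            continue  # optional dependency missing; try the next backend
        invalid_reasons = backend_cls.validate_configuration(
            device_capability, selector_config
        )
        if not invalid_reasons:
            # An empty list means the backend passed every check.
            return backend_cls
    raise ValueError("no valid MLA prefill backend for this configuration")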