vllm.model_executor.models.extract_hidden_states

Hidden States Extractor Model.

This model extracts and caches hidden states from the target model without performing actual token generation. It's used with the extract_hidden_states speculative decoding method.
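
As a hedged sketch of how this method might be selected at the API level (the speculative_config keys below are an assumption inferred from the method name, not an API confirmed by this page):

from vllm import LLM

# Hypothetical usage sketch; the config keys are assumptions, not a
# confirmed API. "extract_hidden_states" is the method name this
# module documents itself as serving.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative target model
    speculative_config={"method": "extract_hidden_states"},
)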

CacheOnlyAttentionBackend

Bases: AttentionBackend

Attention backend that only caches KV without computing attention.

Source code in vllm/model_executor/models/extract_hidden_states.py
class CacheOnlyAttentionBackend(AttentionBackend):
    """Attention backend that only caches KV without computing attention."""

    accept_output_buffer: bool = False
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "bfloat16",
    ]
    forward_includes_kv_cache_update: bool = False

    @staticmethod
    def get_name() -> str:
        return "CACHE_ONLY_ATTN"

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        return attn_type == AttentionType.DECODER

    @classmethod
    def supports_mm_prefix(cls) -> bool:
        return True

    @staticmethod
    def get_impl_cls() -> type["CacheOnlyAttentionImpl"]:
        return CacheOnlyAttentionImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        # We set `num_kv_heads = num_hidden_states` and `head_size = hidden_size`
        # (see ExtractHiddenStatesModel). We also don't use a k/v (2) dim.
        return (num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["CacheOnlyAttentionMetadataBuilder"]:
        return CacheOnlyAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return []
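
To make the shape repurposing concrete, here is a small sketch (the numbers are illustrative assumptions, and vllm must be installed for the import to resolve):

from vllm.model_executor.models.extract_hidden_states import CacheOnlyAttentionBackend

# num_kv_heads plays the role of the number of extracted hidden states,
# and head_size plays the role of the model's hidden size.
shape = CacheOnlyAttentionBackend.get_kv_cache_shape(
    num_blocks=1024,
    block_size=16,
    num_kv_heads=3,   # e.g. three aux hidden-state layers
    head_size=4096,   # e.g. the model's hidden size
)
assert shape == (1024, 16, 3, 4096)  # note: no leading k/v (2) dimension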

CacheOnlyAttentionImpl

Bases: AttentionImpl

Attention implementation that only caches KV states.

Source code in vllm/model_executor/models/extract_hidden_states.py
class CacheOnlyAttentionImpl(AttentionImpl):
    """Attention implementation that only caches KV states."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        kv_cache_dtype: str,
        kv_cache_torch_dtype: torch.dtype,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_cache_torch_dtype = kv_cache_torch_dtype

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(f"Unsupported attention type: {attn_type}")
        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError("Quantized KV cache not supported")

        self.num_queries_per_kv = 1

    def do_kv_cache_update(
        self,
        layer,
        to_cache,
        kv_cache,
        slot_mapping,
    ):
        assert to_cache.dtype == self.kv_cache_torch_dtype, (
            f"Data to cache must be {self.kv_cache_torch_dtype}, got {to_cache.dtype}"
        )
        assert kv_cache.dtype == self.kv_cache_torch_dtype, (
            f"KV cache must be {self.kv_cache_torch_dtype}, got {kv_cache.dtype}"
        )

        basic_cache(to_cache, kv_cache, slot_mapping)

    def forward(self, *args, **kwargs):
        # Empty implementation of abstract method
        pass
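
basic_cache itself is not shown on this page. As a loosely held assumption about its behavior (a plain scatter of token rows into flat cache slots, with no K/V split), an equivalent sketch might be:

import torch

def basic_cache_sketch(
    to_cache: torch.Tensor,      # [num_tokens, num_heads, head_size]
    kv_cache: torch.Tensor,      # [num_blocks, block_size, num_heads, head_size]
    slot_mapping: torch.Tensor,  # [num_tokens], flat slot index per token
) -> None:
    # Hypothetical re-implementation for illustration only: collapse the
    # (num_blocks, block_size) dims into one slot axis and scatter each
    # token's hidden states into its assigned slot.
    flat = kv_cache.view(-1, *kv_cache.shape[2:])
    flat[slot_mapping] = to_cache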

CacheOnlyAttentionLayer

Bases: Module, AttentionLayerBase

Attention layer that only caches key/value states without computing attention.

Source code in vllm/model_executor/models/extract_hidden_states.py
class CacheOnlyAttentionLayer(nn.Module, AttentionLayerBase):
    """Attention layer that only caches key/value states without computing attention."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.head_size = head_size
        self.layer_name = prefix

        vllm_config = get_current_vllm_config()

        # KV cache configuration
        cache_config = cache_config or vllm_config.cache_config
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            self.block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            self.block_size = 16

        assert kv_cache_dtype in ["auto", "bfloat16", "float16"], (
            "CacheOnlyAttentionLayer doesn't currently support quantized kv cache, "
            f"but kv cache dtype was set to {kv_cache_dtype}"
        )
        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )

        # Initialize KV cache quantization attributes
        set_default_quant_scales(self, register_buffer=True)

        # Attention backend
        self.attn_backend = CacheOnlyAttentionBackend
        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            kv_cache_dtype,
            self.kv_cache_torch_dtype,
            attn_type,
        )

        assert not self.attn_backend.forward_includes_kv_cache_update, (
            "KV cache update should be independent of forward"
        )

        # Placeholder KV cache (replaced by bind_kv_cache)
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Register in compilation context
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def forward(self, to_cache: torch.Tensor) -> torch.Tensor:
        """Cache hidden states as KV pairs without computing attention.

        Args:
            to_cache: The tensor to insert into the kv cache.
                shape [num_tokens, num_heads, head_size]

        Returns:
            Dummy output tensor (not used)
        """
        # Note: we set num_heads to num_hidden_states and
        # head_size to hidden_size for hidden states storage
        output = torch.empty(0, device=to_cache.device, dtype=to_cache.dtype)

        # Note: dummy_out is used to force torch.compile to preserve ordering between
        # cache update and attention op (which triggers kv_connector transfer)
        dummy_out = unified_kv_cache_update(to_cache, self.layer_name)

        # Triggers kv_connector transfer via decorator
        _ = dummy_attention(self.layer_name, dummy_out)

        return output

    def get_attn_backend(self) -> type[AttentionBackend]:
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # Note: we use MLAAttentionSpec here because it will
        # produce page sizes of (block_size * num_kv_heads * head_size * dtype_size)
        # whereas FullAttentionSpec will add an additional factor of 2
        return MLAAttentionSpec(
            block_size=self.block_size,
            num_kv_heads=self.num_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
        )
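
The page-size remark in get_kv_cache_spec can be checked with quick arithmetic (the values below are illustrative assumptions):

# Worked example of the page-size note above.
block_size = 16
num_kv_heads = 3   # number of extracted hidden states
head_size = 4096   # hidden size
dtype_size = 2     # bytes per element for bfloat16

mla_page = block_size * num_kv_heads * head_size * dtype_size  # 393216 bytes
full_attn_page = 2 * mla_page  # FullAttentionSpec's extra factor of 2 (K and V)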

forward

forward(to_cache: Tensor) -> Tensor

Cache hidden states as KV pairs without computing attention.

Parameters:

Name      Type    Description                                   Default
to_cache  Tensor  The tensor to insert into the kv cache.       required
                  Shape: [num_tokens, num_heads, head_size]

Returns:

Type    Description
Tensor  Dummy output tensor (not used)

Source code in vllm/model_executor/models/extract_hidden_states.py
def forward(self, to_cache: torch.Tensor) -> torch.Tensor:
    """Cache hidden states as KV pairs without computing attention.

    Args:
        to_cache: The tensor to insert into the kv cache.
            shape [num_tokens, num_heads, head_size]

    Returns:
        Dummy output tensor (not used)
    """
    # Note: we set num_heads to num_hidden_states and
    # head_size to hidden_size for hidden states storage
    output = torch.empty(0, device=to_cache.device, dtype=to_cache.dtype)

    # Note: dummy_out is used to force torch.compile to preserve ordering between
    # cache update and attention op (which triggers kv_connector transfer)
    dummy_out = unified_kv_cache_update(to_cache, self.layer_name)

    # Triggers kv_connector transfer via decorator
    _ = dummy_attention(self.layer_name, dummy_out)

    return output
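
The dummy-output pattern here is a general torch.compile technique: a side-effecting op returns a small token tensor that the next op consumes, so the compiler sees a data dependency it cannot reorder away. A minimal self-contained sketch of the idea (the names are illustrative, not vLLM APIs):

import torch

_cache: dict[str, torch.Tensor] = {}

def cache_update(x: torch.Tensor, key: str) -> torch.Tensor:
    # Side effect: store the tensor, then return a dummy token tensor
    # that downstream ops can consume to pin the ordering.
    _cache[key] = x
    return torch.empty(0, device=x.device, dtype=x.dtype)

def post_cache_hook(token: torch.Tensor) -> torch.Tensor:
    # Because this consumes cache_update's output, any tracer/compiler
    # must schedule it after the cache write.
    return token

x = torch.randn(4, 8)
token = cache_update(x, "layer0")
_ = post_cache_hook(token)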

ExtractHiddenStatesModel

Bases: Module

Source code in vllm/model_executor/models/extract_hidden_states.py
class ExtractHiddenStatesModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        self.vllm_config = vllm_config
        self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        self.hidden_size = vllm_config.model_config.get_hidden_size()
        self.target_num_hidden_layers = (
            vllm_config.model_config.get_total_num_hidden_layers()
        )
        self.num_hidden_states = len(
            getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", [])
        )

        cache_config = vllm_config.cache_config

        # Create a single cache-only attention layer
        # Note: We set num_heads <- self.num_hidden_states
        # and head_size <- hidden_size so that we can insert
        # the hidden states directly into the cache without
        # reshaping
        self.cache_only_layers = nn.ModuleDict(
            {
                str(self.target_num_hidden_layers): CacheOnlyAttentionLayer(
                    num_heads=self.num_hidden_states,
                    head_size=self.hidden_size,
                    cache_config=cache_config,
                    prefix=maybe_prefix(
                        prefix, f"cache_only_layers.{self.target_num_hidden_layers}"
                    ),
                )
            }
        )

    def forward(self, hidden_states: torch.Tensor) -> None:
        """Process and cache hidden states.

        Args:
            hidden_states: Hidden states from target model
                          shape: [num_tokens, num_hidden_states, hidden_size]

        Returns:
            None. Caching happens as a side effect of the layer call.
        """

        # Call dummy attention layer to cache hidden states
        # Output is ignored - we only care about the KV cache side effects
        _ = self.cache_only_layers[str(self.target_num_hidden_layers)](hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """No weights to load for this dummy model."""
        return set()
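
The number of cached hidden states is taken from the target model's eagle_aux_hidden_state_layer_ids. A small self-contained illustration (the stand-in config and layer ids are fabricated for the example):

from types import SimpleNamespace

# Stand-in for an hf_config carrying the aux layer ids; the ids are
# made up for illustration.
hf_config = SimpleNamespace(eagle_aux_hidden_state_layer_ids=[2, 16, 29])
num_hidden_states = len(getattr(hf_config, "eagle_aux_hidden_state_layer_ids", []))
assert num_hidden_states == 3  # one cached hidden state per listed layer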

forward

forward(hidden_states: Tensor) -> None

Process and cache hidden states.

Parameters:

Name           Type    Description                                           Default
hidden_states  Tensor  Hidden states from the target model.                  required
                       Shape: [num_tokens, num_hidden_states, hidden_size]

Returns:

Type  Description
None  Nothing is returned; caching happens as a side effect

Source code in vllm/model_executor/models/extract_hidden_states.py
def forward(self, hidden_states: torch.Tensor) -> None:
    """Process and cache hidden states.

    Args:
        hidden_states: Hidden states from target model
                      shape: [num_tokens, num_hidden_states, hidden_size]

    Returns:
        None. Caching happens as a side effect of the layer call.
    """

    # Call dummy attention layer to cache hidden states
    # Output is ignored - we only care about the KV cache side effects
    _ = self.cache_only_layers[str(self.target_num_hidden_layers)](hidden_states)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

No weights to load for this dummy model.

Source code in vllm/model_executor/models/extract_hidden_states.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """No weights to load for this dummy model."""
    return set()

unified_kv_cache_update

unified_kv_cache_update(
    to_cache: Tensor, layer_name: str
) -> Tensor

Returns a dummy tensor that is passed to unified_attention, signaling the side effect and establishing a data dependency between the two ops so that torch.compile preserves their ordering.

Source code in vllm/model_executor/models/extract_hidden_states.py
def unified_kv_cache_update(
    to_cache: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """
    Returns a dummy that is passed to unified_attention to signal a side effect and
    the data dependency between them to ensure torch.compile preserves ordering.
    """
    forward_context = get_forward_context()
    attn_layer = forward_context.no_compile_layers[layer_name]
    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]

    slot_mapping = forward_context.slot_mapping
    assert isinstance(slot_mapping, dict), (
        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}"
    )
    layer_slot_mapping = slot_mapping.get(layer_name)
    if layer_slot_mapping is not None:
        assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
            f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
        )
        attn_layer.impl.do_kv_cache_update(
            attn_layer,
            to_cache,
            kv_cache,
            layer_slot_mapping,
        )

    return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
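
The slot_mapping here is a per-layer dict, so a layer with no entry for the current step is simply skipped. A toy illustration of that guard (the layer name and slot values are fabricated):

import torch

slot_mapping = {"cache_only_layers.32": torch.tensor([4, 5, 6])}

# No entry for this layer this step: the update is skipped entirely.
assert slot_mapping.get("some_other_layer") is None

# An entry means these token rows get written at the mapped flat slots.
assert slot_mapping.get("cache_only_layers.32") is not None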