Skip to content

vllm.v1.spec_decode.extract_hidden_states

ExtractHiddenStatesProposer

Source code in vllm/v1/spec_decode/extract_hidden_states.py
class ExtractHiddenStatesProposer:
    """Speculative-decoding "proposer" whose job is to cache hidden states.

    Instead of drafting tokens, it feeds the target model's auxiliary hidden
    states through an ExtractHiddenStatesModel so they are written to the KV
    cache, and hands the already-sampled tokens back as the single "draft"
    token per request (see ``propose``).
    """

    def __init__(self, vllm_config: VllmConfig, device):
        """Validate the speculative config and allocate persistent buffers.

        Args:
            vllm_config: Global vLLM config; must carry a speculative_config
                with ``num_speculative_tokens == 1`` and a draft model whose
                HF config defines ``eagle_aux_hidden_state_layer_ids``.
            device: Device on which the hidden-state and slot-mapping
                buffers are allocated.

        Raises:
            ValueError: If ``disable_padded_drafter_batch`` is set, or if
                ``eagle_aux_hidden_state_layer_ids`` is missing/empty.
        """
        assert vllm_config.speculative_config is not None

        # Exactly one "draft" token is returned per request (see propose()).
        assert vllm_config.speculative_config.num_speculative_tokens == 1
        if vllm_config.speculative_config.disable_padded_drafter_batch:
            raise ValueError(
                "disable_padded_drafter_batch is not supported with "
                "extract_hidden_states method"
            )
        self.vllm_config = vllm_config
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank

        # Model and attention layer tracking (initialized in load_model)
        self.model: nn.Module | None = None
        self.attn_layer_names: list[str] = []
        self.attn_metadata_builder: AttentionMetadataBuilder | None = None

        # Maximum number of tokens for buffers
        # NOTE(review): the extra max_num_seqs is presumably headroom of one
        # additional token per sequence -- confirm against the model runner.
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
        )

        self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        layer_ids = getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", None)
        if not layer_ids:
            raise ValueError(
                "eagle_aux_hidden_state_layer_ids must be set in the draft "
                "model config for extract_hidden_states method"
            )
        # One hidden-state slice is stored per configured aux layer.
        self.num_hidden_states = len(layer_ids)
        self.hidden_size = vllm_config.model_config.get_hidden_size()
        # Persistent input buffer: [max_num_tokens, num_hidden_states, hidden_size].
        # propose() copies into it and the model always reads a prefix of it.
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.num_hidden_states, self.hidden_size),
            dtype=self.dtype,
            device=device,
        )
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

        # Persistent slot-mapping buffer; _get_slot_mapping() hands out
        # prefix views of it to every cache-only attention layer.
        self._slot_mapping_buffer = torch.zeros(
            self.max_num_tokens, dtype=torch.int64, device=device
        )

    def propose(
        self,
        sampled_token_ids: torch.Tensor,
        target_hidden_states: list[torch.Tensor],
        common_attn_metadata: CommonAttentionMetadata,
        scheduler_output: SchedulerOutput,
        slot_mappings: dict[str, torch.Tensor]
        | list[dict[str, torch.Tensor]]
        | None = None,
    ) -> tuple[torch.Tensor, KVConnectorOutput | None]:
        """Propose draft tokens by calling the ExtractHiddenStatesModel model.

        The ExtractHiddenStatesModel caches the hidden states in the KV cache
        without performing actual attention computation. This allows us to
        extract and store hidden states for later use (e.g., KV transfer).

        This proposer doesn't actually perform speculation - it returns the
        sampled tokens as "draft" tokens, ensuring they always verify (match).
        The main purpose is to cache hidden states, not to speculate.

        Args:
            sampled_token_ids: Sampled token IDs from the target model
            target_hidden_states: List of hidden state tensors from target model
                                (one per aux hidden state layer)
            common_attn_metadata: Attention metadata
            scheduler_output: Scheduler output for KV connector
            slot_mappings: Slot mappings for KV cache (unused, provided for
                          interface compatibility)

        Returns:
            Tuple of:
                - Draft tokens matching sampled tokens, shape [batch_size, 1]
                - KV connector output (if KV transfer is active), else None
        """
        assert self.model is not None and isinstance(target_hidden_states, list)

        # target_hidden_states is a list of tensors (one per layer)
        # Each tensor has shape [num_tokens, hidden_size]
        # Stack to shape: [num_tokens, num_hidden_states, hidden_size]
        stacked_hidden_states = torch.stack(target_hidden_states, dim=1)
        num_tokens = stacked_hidden_states.shape[0]

        # Copy hidden states to buffer
        self.hidden_states[:num_tokens] = stacked_hidden_states

        assert self.attn_metadata_builder is not None
        attn_metadata = self.attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )

        # We assume all cache-only layers belong to the same KV cache group,
        # thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

        # num_input_tokens is the (possibly padded) token count actually fed
        # to the model; it may exceed num_tokens.
        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(num_tokens)
        )
        if num_tokens_across_dp is not None:
            num_tokens_across_dp[self.dp_rank] = num_input_tokens

        with (
            set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                # Slots beyond the real mapping are filled with
                # PADDING_SLOT_ID by _get_slot_mapping().
                slot_mapping=self._get_slot_mapping(
                    num_input_tokens, common_attn_metadata.slot_mapping
                ),
            ),
            # Only collect KV connector output when a KV transfer group
            # exists; nullcontext() binds kv_connector_output to None.
            (
                KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
                if has_kv_transfer_group()
                else nullcontext()
            ) as kv_connector_output,
        ):
            # The model's return value is discarded; the call is made for
            # its caching side effect.
            self.model(
                hidden_states=self.hidden_states[:num_input_tokens],
            )

        # Return the sampled tokens as "draft" tokens
        # Shape: [batch_size, 1] to match num_speculative_tokens=1
        return sampled_token_ids.unsqueeze(-1), kv_connector_output

    def _get_slot_mapping(
        self,
        num_tokens: int,
        slot_mapping: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Return slot_mapping dict for cache-only attention layers.

        If slot_mapping is provided, copies it into the buffer first.

        Args:
            num_tokens: Length of the buffer prefix to expose.
            slot_mapping: Optional mapping to copy into the start of the
                persistent buffer. Entries between its length and
                ``num_tokens`` are filled with ``PADDING_SLOT_ID``.

        Returns:
            Dict mapping every name in ``self.attn_layer_names`` to the same
            view of the first ``num_tokens`` buffer entries.
        """
        if slot_mapping is not None:
            num_actual = slot_mapping.shape[0]
            self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
            if num_tokens > num_actual:
                self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)

        # When slot_mapping is None the buffer is returned as-is, i.e. with
        # whatever contents an earlier call left in it.
        view = self._slot_mapping_buffer[:num_tokens]
        return {name: view for name in self.attn_layer_names}

    def _determine_batch_execution_and_padding(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
    ) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
        """Choose the cudagraph mode and padded token count for a forward pass.

        Dispatches through ``self.cudagraph_dispatcher`` and, when
        ``data_parallel_size > 1``, coordinates with the other DP ranks via
        ``coordinate_batch_across_dp`` and re-dispatches with the agreed
        padded size.

        Args:
            num_tokens: Unpadded number of tokens in this batch.
            use_cudagraphs: If False, dispatch is restricted to
                ``CUDAGraphMode.NONE``.

        Returns:
            Tuple of (cudagraph mode, padded token count, per-DP-rank token
            counts). The last element is None when data parallelism is off or
            when DP coordination returned None.
        """
        cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            num_tokens,
            valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
        )
        num_tokens_padded = batch_desc.num_tokens

        # Extra coordination when running data-parallel since we need to
        # coordinate across ranks
        # TODO(Flechman): support DBO ubatching
        should_ubatch, num_tokens_across_dp = False, None
        if self.vllm_config.parallel_config.data_parallel_size > 1:
            should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
                coordinate_batch_across_dp(
                    num_tokens_unpadded=num_tokens,
                    parallel_config=self.vllm_config.parallel_config,
                    allow_microbatching=False,
                    num_tokens_padded=num_tokens_padded,
                    cudagraph_mode=cudagraph_mode.value,
                )
            )
            assert not should_ubatch, (
                "DBO ubatching not implemented for extract_hidden_states"
            )

            # Extract DP-synced values
            if num_tokens_across_dp is not None:
                dp_rank = self.dp_rank
                num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
                # Re-dispatch with DP padding so we have the correct
                # batch_descriptor
                cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
                    num_tokens_padded,
                    valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
                )
                # Assert to make sure the agreed upon token count is correct
                # otherwise num_tokens_across_dp will no-longer be valid
                assert batch_desc.num_tokens == num_tokens_padded
                num_tokens_across_dp[dp_rank] = num_tokens_padded

        return cudagraph_mode, num_tokens_padded, num_tokens_across_dp

    def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
        """Initialize cudagraph dispatcher keys.

        Only supports PIECEWISE cudagraphs (via mixed_mode).
        Should be called after adjust_cudagraph_sizes_for_spec_decode.

        Args:
            cudagraph_mode: Cudagraph mode to derive the proposer's mode
                from. If its ``mixed_mode()`` is PIECEWISE or FULL (and the
                speculative config does not enforce eager), the proposer
                uses PIECEWISE; otherwise it uses NONE.
        """
        assert self.vllm_config.speculative_config is not None
        if (
            not self.vllm_config.speculative_config.enforce_eager
            and cudagraph_mode.mixed_mode()
            in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
        ):
            # FULL is deliberately downgraded to PIECEWISE for the proposer.
            proposer_cudagraph_mode = CUDAGraphMode.PIECEWISE
        else:
            proposer_cudagraph_mode = CUDAGraphMode.NONE

        self.cudagraph_dispatcher.initialize_cudagraph_keys(proposer_cudagraph_mode)

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
        slot_mappings: dict[str, torch.Tensor] | None = None,
    ) -> None:
        """Run the model once on the persistent hidden-state buffer.

        No attention metadata is supplied (``set_forward_context`` receives
        ``None``), and the buffer contents are used as-is.

        Args:
            num_tokens: Unpadded token count to run with.
            use_cudagraphs: If False, dispatch is restricted to
                ``CUDAGraphMode.NONE``.
            is_graph_capturing: Not read by this method.
            slot_mappings: Optional per-layer slot mappings. If it contains
                this proposer's first attention layer, the proposer's own
                buffer view is used instead; otherwise it is passed through
                (or ``{}`` when None/empty).
        """
        assert self.model is not None, "Model must be initialized before dummy_run"
        cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
            self._determine_batch_execution_and_padding(
                num_tokens, use_cudagraphs=use_cudagraphs
            )
        )

        if num_tokens_across_dp is not None:
            num_tokens_across_dp[self.dp_rank] = num_input_tokens

        # Use our own slot mapping buffer during cudagraph capture.
        if (
            self.attn_layer_names
            and slot_mappings is not None
            and self.attn_layer_names[0] in slot_mappings
        ):
            # No tensor is passed, so the buffer is not overwritten here.
            slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
        else:
            slot_mapping_dict = slot_mappings or {}

        with set_forward_context(
            None,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=slot_mapping_dict,
        ):
            self.model(
                hidden_states=self.hidden_states[:num_input_tokens],
            )

    def _build_attn_metadata_builder(
        self, draft_attn_layers: dict[str, AttentionLayerBase]
    ) -> AttentionMetadataBuilder:
        """Build the attention metadata builder from draft attention layers.

        The first layer in ``draft_attn_layers`` supplies both the attention
        backend and the KV cache spec.

        Raises:
            ValueError: If ``draft_attn_layers`` is empty.
        """
        if not draft_attn_layers:
            raise ValueError("No attention layers found for ExtractHiddenStatesModel")
        layer = next(iter(draft_attn_layers.values()))
        attn_backend = layer.get_attn_backend()
        return attn_backend.get_builder_cls()(
            layer.get_kv_cache_spec(self.vllm_config),
            self.attn_layer_names,
            self.vllm_config,
            self.device,
        )

    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare next token IDs for speculative decoding.

        Since num_speculative_tokens == 1, sampled_token_ids has shape
        (batch_size, 1). For each request we either use the sampled token
        (if valid and not discarded) or a backup token from the request state.

        Returns:
            Tuple of (next_token_ids, valid_sampled_tokens_count), both int32
            with one entry per request.
        """
        num_reqs = gpu_input_batch.num_reqs
        device = sampled_token_ids.device

        # Compute backup tokens for discarded / invalid requests
        # The backup is the request's token at index seq_len, looked up on
        # the CPU side and then moved to the sampled tokens' device.
        backup_tokens_gpu = torch.tensor(
            [
                requests[gpu_input_batch.req_ids[i]].get_token_id(
                    common_attn_metadata.seq_lens_cpu[i].item()
                )
                for i in range(num_reqs)
            ],
            dtype=torch.int32,
            device=device,
        )

        assert discard_request_mask.dtype == torch.bool

        # With num_speculative_tokens == 1, there is exactly one token
        sampled = sampled_token_ids[:, 0]
        # A token is valid when it lies in [0, vocab_size).
        is_valid = (sampled >= 0) & (sampled < gpu_input_batch.vocab_size)
        # The count reflects token validity only; the discard mask is applied
        # to next_token_ids below but not to this count.
        valid_sampled_tokens_count = is_valid.to(torch.int32)

        use_sampled = is_valid & ~discard_request_mask[:num_reqs]
        next_token_ids = torch.where(
            use_sampled, sampled.to(torch.int32), backup_tokens_gpu
        )

        return next_token_ids, valid_sampled_tokens_count

    def load_model(self, target_model: nn.Module) -> None:
        """Load the ExtractHiddenStatesModel model.

        This method instantiates the ExtractHiddenStatesModel model which is used
        to cache hidden states during speculative decoding. The model uses
        cache-only attention (no computation, just caching KV states).

        Args:
            target_model: The target model (passed for compatibility with
                         EagleProposer interface, but not used here)
        """
        # Get the target model's attention layers before loading draft model
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()  # type: ignore[type-abstract]
        )

        assert self.vllm_config.speculative_config is not None
        draft_model_config = self.vllm_config.speculative_config.draft_model_config
        from vllm.compilation.backends import set_model_tag

        with set_model_tag("extract_hidden_states"):
            self.model = get_model(
                vllm_config=self.vllm_config, model_config=draft_model_config
            )

        # Identify draft model's attention layers (difference from target)
        all_attn_layers = get_layers_from_vllm_config(
            self.vllm_config,
            AttentionLayerBase,  # type: ignore[type-abstract]
        )
        draft_attn_layers = {
            name: layer
            for name, layer in all_attn_layers.items()
            if name not in target_attn_layer_names
        }
        # attn_layer_names is assigned before the assertion, so it is set
        # even when the assertion below fails.
        self.attn_layer_names = list(draft_attn_layers.keys())
        assert len(draft_attn_layers) == 1, (
            "ExtractHiddenStatesModel should have exactly one "
            f"attention layer, found {len(draft_attn_layers)}"
        )
        self.attn_metadata_builder = self._build_attn_metadata_builder(
            draft_attn_layers
        )

    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """Validate all drafting layers belong to the same KV cache group.

        With exactly one attention layer (asserted in load_model), this is
        trivially satisfied.

        Args:
            kv_cache_config: Not read by this method.
        """
        assert len(self.attn_layer_names) == 1

_build_attn_metadata_builder

_build_attn_metadata_builder(
    draft_attn_layers: dict[str, AttentionLayerBase],
) -> AttentionMetadataBuilder

Build the attention metadata builder from draft attention layers.

Source code in vllm/v1/spec_decode/extract_hidden_states.py
def _build_attn_metadata_builder(
    self, draft_attn_layers: dict[str, AttentionLayerBase]
) -> AttentionMetadataBuilder:
    """Create the attention metadata builder for the draft attention layers.

    The first layer (in dict order) is used as the representative: it
    provides the attention backend and the KV cache spec handed to the
    backend's builder class.

    Raises:
        ValueError: If ``draft_attn_layers`` is empty.
    """
    if not draft_attn_layers:
        raise ValueError("No attention layers found for ExtractHiddenStatesModel")

    representative = next(iter(draft_attn_layers.values()))
    # Resolve the builder class before computing the KV cache spec, which
    # is the order the calls are made in when written as one expression.
    builder_cls = representative.get_attn_backend().get_builder_cls()
    kv_cache_spec = representative.get_kv_cache_spec(self.vllm_config)
    return builder_cls(
        kv_cache_spec,
        self.attn_layer_names,
        self.vllm_config,
        self.device,
    )

_get_slot_mapping

_get_slot_mapping(
    num_tokens: int, slot_mapping: Tensor | None = None
) -> dict[str, Tensor]

Return slot_mapping dict for cache-only attention layers.

If slot_mapping is provided, copies it into the buffer first.

Source code in vllm/v1/spec_decode/extract_hidden_states.py
def _get_slot_mapping(
    self,
    num_tokens: int,
    slot_mapping: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
    """Return slot_mapping dict for cache-only attention layers.

    If slot_mapping is provided, copies it into the buffer first.
    """
    if slot_mapping is not None:
        num_actual = slot_mapping.shape[0]
        self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
        if num_tokens > num_actual:
            self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)

    view = self._slot_mapping_buffer[:num_tokens]
    return {name: view for name in self.attn_layer_names}

initialize_cudagraph_keys

initialize_cudagraph_keys(
    cudagraph_mode: CUDAGraphMode,
) -> None

Initialize cudagraph dispatcher keys.

Only supports PIECEWISE cudagraphs (via mixed_mode). Should be called after adjust_cudagraph_sizes_for_spec_decode.

Source code in vllm/v1/spec_decode/extract_hidden_states.py
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
    """Set up the cudagraph dispatcher keys for this proposer.

    The proposer only ever runs with PIECEWISE cudagraphs or none at all:
    a mixed mode of PIECEWISE or FULL maps to PIECEWISE (unless the
    speculative config enforces eager execution); anything else maps to NONE.
    Should be called after adjust_cudagraph_sizes_for_spec_decode.
    """
    spec_config = self.vllm_config.speculative_config
    assert spec_config is not None

    # Default to no cudagraphs; upgrade only when eager is not enforced and
    # the requested mixed mode is one the proposer can serve.
    proposer_mode = CUDAGraphMode.NONE
    if not spec_config.enforce_eager:
        graph_capable = (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL)
        if cudagraph_mode.mixed_mode() in graph_capable:
            proposer_mode = CUDAGraphMode.PIECEWISE

    self.cudagraph_dispatcher.initialize_cudagraph_keys(proposer_mode)

load_model

load_model(target_model: Module) -> None

Load the ExtractHiddenStatesModel model.

This method instantiates the ExtractHiddenStatesModel model which is used to cache hidden states during speculative decoding. The model uses cache-only attention (no computation, just caching KV states).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `target_model` | `Module` | The target model (passed for compatibility with the EagleProposer interface, but not used here) | required |
Source code in vllm/v1/spec_decode/extract_hidden_states.py
def load_model(self, target_model: nn.Module) -> None:
    """Instantiate the ExtractHiddenStatesModel and register its attention layer.

    The attention layers known to the config are snapshotted before the
    draft model is created; whatever layers appear afterwards belong to the
    draft model. Exactly one such layer is expected, and it is used to build
    the attention metadata builder.

    Args:
        target_model: The target model (accepted for compatibility with the
            EagleProposer interface, but not used here).
    """
    # Snapshot the attention layers that exist before the draft model loads.
    layers_before = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)  # type: ignore[type-abstract]
    preexisting_names = set(layers_before.keys())

    spec_config = self.vllm_config.speculative_config
    assert spec_config is not None
    draft_model_config = spec_config.draft_model_config
    from vllm.compilation.backends import set_model_tag

    with set_model_tag("extract_hidden_states"):
        self.model = get_model(
            vllm_config=self.vllm_config, model_config=draft_model_config
        )

    # Anything not present in the snapshot was added by the draft model.
    layers_after = get_layers_from_vllm_config(
        self.vllm_config,
        AttentionLayerBase,  # type: ignore[type-abstract]
    )
    draft_attn_layers = {}
    for layer_name, layer in layers_after.items():
        if layer_name in preexisting_names:
            continue
        draft_attn_layers[layer_name] = layer

    self.attn_layer_names = list(draft_attn_layers.keys())
    num_draft_layers = len(draft_attn_layers)
    assert num_draft_layers == 1, (
        "ExtractHiddenStatesModel should have exactly one "
        f"attention layer, found {num_draft_layers}"
    )
    self.attn_metadata_builder = self._build_attn_metadata_builder(
        draft_attn_layers
    )

prepare_next_token_ids_padded

prepare_next_token_ids_padded(
    common_attn_metadata: CommonAttentionMetadata,
    sampled_token_ids: Tensor,
    requests: dict[str, CachedRequestState],
    gpu_input_batch: InputBatch,
    discard_request_mask: Tensor,
) -> tuple[Tensor, Tensor]

Prepare next token IDs for speculative decoding.

Since num_speculative_tokens == 1, sampled_token_ids has shape (batch_size, 1). For each request we either use the sampled token (if valid and not discarded) or a backup token from the request state.

Source code in vllm/v1/spec_decode/extract_hidden_states.py
def prepare_next_token_ids_padded(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    sampled_token_ids: torch.Tensor,
    requests: dict[str, CachedRequestState],
    gpu_input_batch: InputBatch,
    discard_request_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Select the next token ID for each request in the batch.

    With num_speculative_tokens == 1, sampled_token_ids has shape
    (batch_size, 1). A request keeps its sampled token when that token lies
    in [0, vocab_size) and the request is not discarded; otherwise a backup
    token taken from the request state is used.

    Returns:
        Tuple of (next_token_ids, valid_sampled_tokens_count), both int32.
        The count reflects token validity only, not the discard mask.
    """
    num_reqs = gpu_input_batch.num_reqs
    device = sampled_token_ids.device

    # Gather one backup token per request: the request's token at index
    # seq_len, used for discarded / invalid requests.
    backup_ids = []
    for req_index in range(num_reqs):
        req_state = requests[gpu_input_batch.req_ids[req_index]]
        seq_len = common_attn_metadata.seq_lens_cpu[req_index].item()
        backup_ids.append(req_state.get_token_id(seq_len))
    backup_tokens_gpu = torch.tensor(backup_ids, dtype=torch.int32, device=device)

    assert discard_request_mask.dtype == torch.bool

    # Exactly one sampled token per request.
    sampled = sampled_token_ids[:, 0]
    non_negative = sampled >= 0
    below_vocab = sampled < gpu_input_batch.vocab_size
    is_valid = non_negative & below_vocab
    valid_sampled_tokens_count = is_valid.to(torch.int32)

    not_discarded = ~discard_request_mask[:num_reqs]
    use_sampled = is_valid & not_discarded
    sampled_i32 = sampled.to(torch.int32)
    next_token_ids = torch.where(use_sampled, sampled_i32, backup_tokens_gpu)

    return next_token_ids, valid_sampled_tokens_count

propose

propose(
    sampled_token_ids: Tensor,
    target_hidden_states: list[Tensor],
    common_attn_metadata: CommonAttentionMetadata,
    scheduler_output: SchedulerOutput,
    slot_mappings: dict[str, Tensor]
    | list[dict[str, Tensor]]
    | None = None,
) -> tuple[Tensor, KVConnectorOutput | None]

Propose draft tokens by calling the ExtractHiddenStatesModel model.

The ExtractHiddenStatesModel caches the hidden states in the KV cache without performing actual attention computation. This allows us to extract and store hidden states for later use (e.g., KV transfer).

This proposer doesn't actually perform speculation - it returns the sampled tokens as "draft" tokens, ensuring they always verify (match). The main purpose is to cache hidden states, not to speculate.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `sampled_token_ids` | `Tensor` | Sampled token IDs from the target model | required |
| `target_hidden_states` | `list[Tensor]` | List of hidden state tensors from the target model (one per aux hidden state layer) | required |
| `common_attn_metadata` | `CommonAttentionMetadata` | Attention metadata | required |
| `scheduler_output` | `SchedulerOutput` | Scheduler output for the KV connector | required |
| `slot_mappings` | `dict[str, Tensor] \| list[dict[str, Tensor]] \| None` | Slot mappings for the KV cache (unused, provided for interface compatibility) | `None` |

Returns:

| Type | Description |
|------|-------------|
| `tuple[Tensor, KVConnectorOutput \| None]` | Tuple of: (1) draft tokens matching the sampled tokens, shape `[batch_size, 1]`; (2) the KV connector output if KV transfer is active, else `None`. |

Source code in vllm/v1/spec_decode/extract_hidden_states.py
def propose(
    self,
    sampled_token_ids: torch.Tensor,
    target_hidden_states: list[torch.Tensor],
    common_attn_metadata: CommonAttentionMetadata,
    scheduler_output: SchedulerOutput,
    slot_mappings: dict[str, torch.Tensor]
    | list[dict[str, torch.Tensor]]
    | None = None,
) -> tuple[torch.Tensor, KVConnectorOutput | None]:
    """Run the ExtractHiddenStatesModel and echo the sampled tokens as drafts.

    No real speculation happens here. The target model's hidden states are
    pushed through the ExtractHiddenStatesModel, which stores them in the KV
    cache without computing attention, so they can be used later (e.g., for
    KV transfer). The sampled tokens are returned as the "draft" tokens, so
    they always verify (match).

    Args:
        sampled_token_ids: Sampled token IDs from the target model
        target_hidden_states: List of hidden state tensors from target model
                            (one per aux hidden state layer)
        common_attn_metadata: Attention metadata
        scheduler_output: Scheduler output for KV connector
        slot_mappings: Slot mappings for KV cache (unused, provided for
                      interface compatibility)

    Returns:
        Tuple of:
            - Draft tokens matching sampled tokens, shape [batch_size, 1]
            - KV connector output (if KV transfer is active), else None
    """
    assert self.model is not None
    assert isinstance(target_hidden_states, list)

    # Each list entry is [num_tokens, hidden_size]; stacking along dim 1
    # yields [num_tokens, num_hidden_states, hidden_size].
    layer_stack = torch.stack(target_hidden_states, dim=1)
    num_tokens = layer_stack.size(0)

    # Stage the stacked states in the persistent input buffer.
    self.hidden_states[:num_tokens] = layer_stack

    assert self.attn_metadata_builder is not None
    drafting_metadata = self.attn_metadata_builder.build_for_drafting(
        common_attn_metadata=common_attn_metadata, draft_index=0
    )

    # All cache-only layers are assumed to share one KV cache group, so
    # they all receive the same attention metadata object.
    per_layer_attn_metadata = dict.fromkeys(
        self.attn_layer_names, drafting_metadata
    )

    runtime_mode, padded_num_tokens, tokens_across_dp = (
        self._determine_batch_execution_and_padding(num_tokens)
    )
    if tokens_across_dp is not None:
        tokens_across_dp[self.dp_rank] = padded_num_tokens

    # Refresh the slot-mapping buffer right before opening the forward context.
    layer_slot_mappings = self._get_slot_mapping(
        padded_num_tokens, common_attn_metadata.slot_mapping
    )

    with (
        set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=padded_num_tokens,
            num_tokens_across_dp=tokens_across_dp,
            cudagraph_runtime_mode=runtime_mode,
            slot_mapping=layer_slot_mappings,
        ),
        (
            KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
            if has_kv_transfer_group()
            else nullcontext()
        ) as kv_connector_output,
    ):
        # The return value is discarded; the call is made for its side effect.
        self.model(
            hidden_states=self.hidden_states[:padded_num_tokens],
        )

    # The sampled tokens double as the "draft" tokens:
    # shape [batch_size, 1] to match num_speculative_tokens=1.
    draft_token_ids = sampled_token_ids.unsqueeze(-1)
    return draft_token_ids, kv_connector_output

validate_same_kv_cache_group

validate_same_kv_cache_group(
    kv_cache_config: KVCacheConfig,
) -> None

Validate all drafting layers belong to the same KV cache group.

With exactly one attention layer (asserted in load_model), this is trivially satisfied.

Source code in vllm/v1/spec_decode/extract_hidden_states.py
def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
    """Check that the drafting layers share a single KV cache group.

    load_model asserts that there is exactly one attention layer, so the
    check reduces to confirming that a single layer name is registered.
    ``kv_cache_config`` is not read.
    """
    num_drafting_layers = len(self.attn_layer_names)
    assert num_drafting_layers == 1