vllm.spec_decode.draft_model_runner

allow_gpu_advance_step module-attribute

allow_gpu_advance_step = True

debug_advance_input module-attribute

debug_advance_input = False

logger module-attribute

logger = init_logger(__name__)

TP1DraftModelRunner

Bases: ModelRunnerWrapperBase

Specialized model runner for the speculative decoding draft model. Since the draft model always executes k forward passes consecutively to generate k speculative tokens in a single speculative decoding step, we can get rid of most CPU-GPU synchronization and data transfer overhead by keeping model input and output tensors on the GPU at all times.

TODOs:
  1. Currently supports only flash-attn; add support for other attn_backends.
  2. Support TP > 1 (this requires some design work because we do not expect any broadcasting inside execute_model).
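
The core idea can be sketched outside of vLLM: run k greedy draft steps back to back while every tensor stays on the model's device, so the host never blocks on a GPU-to-CPU copy inside the loop. The toy model and the draft_k_tokens helper below are illustrative assumptions, not vLLM APIs.

import torch

# Toy "draft model": an embedding plus a linear head standing in for a real LM.
# vocab_size, hidden and draft_k_tokens are illustrative names, not part of vLLM.
vocab_size, hidden = 128, 16
torch.manual_seed(0)
embed = torch.nn.Embedding(vocab_size, hidden)
lm_head = torch.nn.Linear(hidden, vocab_size)

@torch.inference_mode()
def draft_k_tokens(tokens: torch.Tensor, k: int) -> torch.Tensor:
    """Run k consecutive greedy draft steps, keeping all tensors on tokens.device.

    No .cpu(), .item() or .tolist() call appears inside the loop, so on a GPU the
    host does not synchronize until the caller explicitly transfers the result.
    """
    drafted = []
    for _ in range(k):
        logits = lm_head(embed(tokens))   # (batch, vocab)
        tokens = logits.argmax(dim=-1)    # sampled ids stay on the same device
        drafted.append(tokens)
    return torch.stack(drafted, dim=1)    # (batch, k) speculative tokens

print(draft_k_tokens(torch.tensor([3, 7]), k=4))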

Source code in vllm/spec_decode/draft_model_runner.py
class TP1DraftModelRunner(ModelRunnerWrapperBase):
    """Specialized model runner for speculative decoding draft model.
    Since the draft model always executes k forward passes consecutively to
    generate k speculative tokens in a single speculative decoding step,
    we could get rid of most CPU-GPU synchronization and data transfer
    overheads by keeping model input and output tensors on GPU all the time.

    TODOs:
    1. Currently supports only flash-attn, add support for other attn_backends.
    2. Support TP > 1 (this requires some design work because we do not
       expect any broadcasting inside execute_model).
    """

    def __init__(self, model_runner: ModelRunnerBase):
        super().__init__(model_runner)

        self.indices_of_seq_with_bonus_tokens = None

    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                  num_queries):

        assert sampling_metadata.num_prompts == 0
        assert len(sampling_metadata.seq_groups) == num_queries
        assert sampling_metadata.selected_token_indices.shape == (
            num_queries, )
        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501

        # Verify that all sequences are decodes
        for i in range(num_queries):
            seq_group = sampling_metadata.seq_groups[i]

            assert seq_group.is_prompt is False  # No prompt
            assert seq_group.prompt_logprob_indices == []  # No prompt
            assert seq_group.sample_indices == [i]  # Simple

    def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
                          last_output: SamplerOutput) -> ModelRunnerInputBase:
        # Currently, we expect "decode mode" only
        assert not model_input.is_prompt

        # Get num_seqs
        num_seqs = len(model_input.seq_lens)
        num_queries = len(model_input.query_lens)

        # Get output tokens GPU tensor
        sampled_token_ids = last_output.sampled_token_ids
        assert sampled_token_ids is not None

        # Update attn_metadata
        attn_metadata = model_input.attn_metadata
        assert isinstance(attn_metadata, FlashAttentionMetadata)

        attn_metadata.advance_step(model_input, sampled_token_ids,
                                   self.block_size, num_seqs, num_queries)

        # Update sampling_metadata
        sampling_metadata = model_input.sampling_metadata
        self._update_sampling_metadata(sampling_metadata, num_seqs,
                                       num_queries)

        # Create new input
        new_model_input = self._model_input_cls(
            input_tokens=model_input.input_tokens,
            input_positions=model_input.input_positions,
            attn_metadata=attn_metadata,
            seq_lens=attn_metadata.seq_lens,
            query_lens=model_input.query_lens,
            lora_mapping=model_input.lora_mapping,
            lora_requests=model_input.lora_requests,
            multi_modal_kwargs=model_input.multi_modal_kwargs,
            sampling_metadata=model_input.sampling_metadata,
            is_prompt=False,
        )

        # Ensure we skip CPU samples
        assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
        # We can reuse sampling tensors since every decode iteration is the same
        new_model_input.sampling_metadata.reuse_sampling_tensors = True

        if debug_advance_input:
            logger.debug("NEW INPUT: ")
            logger.debug("  input_tokens = %s", new_model_input.input_tokens)
            logger.debug("  input_positions = %s",
                         new_model_input.input_positions)
            logger.debug("  seq_lens = %d", new_model_input.seq_lens)
            logger.debug("  query_lens = %d", new_model_input.query_lens)
            logger.debug("  attn_metadata:")
            logger.debug("    seq_lens_tensor: %s",
                         attn_metadata.seq_lens_tensor)
            logger.debug("    slot_mapping: %s", attn_metadata.slot_mapping)
            logger.debug("    block_tables: %s", attn_metadata.block_tables)

        return new_model_input

    def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
        """Determines if draft_model_runner GPU multi-step can be used.
        Currently required conditions are:
            1. Only decodes
            2. Only flash-attn
            3. No LORA
            4. No prompt_adapter_config
        """
        if not allow_gpu_advance_step:
            return False

        # We allow multi-step GPU only in decode mode
        for seq_group in execute_model_req.seq_group_metadata_list:
            if seq_group.is_prompt:
                return False

        # TODO: Add support for other attn backends
        if self.attn_backend.get_name() not in ("FLASH_ATTN", ):
            return False

        # TODO: Add support for LORA
        if self.lora_config:
            return False

        # TODO: Add soft-tuning prompt adapter support
        return not self.prompt_adapter_config

    def set_indices_of_seq_with_bonus_tokens(self,
                                             indices_of_seq_with_bonus_tokens):
        self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelRunnerInputBase,
        kv_caches: List[torch.Tensor],
        previous_hidden_states: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        **kwargs,
    ) -> Optional[List[SamplerOutput]]:
        """Executes num_steps forward passes with advacement of input tensors
        on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.

        Optimizations used:
            1. Input tensors are updated on the GPU directly
            2. Skips GPU=>CPU serialization of sampler outputs (we don't need
                them since we do batch expansion later that uses GPU outputs)
            3. Reuses sampling tensors (since we run only decodes and they have
                a repeating sampling logic)
        """

        # When num_steps == 1, we execute the fallback here for the GPU
        # advance_step, which runs prepare_inputs on CPU and for each spec
        # iteration invokes this function only once
        # (Look at multi-step-worker code)
        is_fallback = num_steps == 1
        if not is_fallback:
            # Since we do not broadcast data inside execute_model anymore,
            # we need to figure out the best way to support TP > 1 in this
            # case, because we will at least need to broadcast the sampled
            # tokens to all workers.
            if not self.is_driver_worker:
                raise ValueError("TP1DraftModelRunner only supports TP=1.")

            # Sanity
            if self.lora_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for LORA")
            if self.prompt_adapter_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for "
                                 "prompt_adapter_config")
            if model_input.inputs_embeds is not None:
                raise ValueError("TP1DraftModelRunner has no support for "
                                 "inputs_embeds")
            if model_input.multi_modal_kwargs:
                raise ValueError(
                    "TP1DraftModelRunner has no support for multi_modal_kwargs"
                )
        else:
            if self.lora_config:
                assert model_input.lora_requests is not None
                assert model_input.lora_mapping is not None
                self.set_active_loras(model_input.lora_requests,
                                      model_input.lora_mapping)

            if self.prompt_adapter_config:
                assert model_input.prompt_adapter_requests is not None
                assert model_input.prompt_adapter_mapping is not None
                self.set_active_prompt_adapters(
                    model_input.prompt_adapter_requests,
                    model_input.prompt_adapter_mapping)

            self.attn_state.begin_forward(model_input)

        # Detect exec mode
        assert model_input.attn_metadata is not None
        use_cuda_graph = False
        if model_input.attn_metadata.num_prefills > 0:
            # In this case, execute_model(..) was called directly
            if num_steps > 1:
                raise ValueError(
                    "execute_model(..) of draft_model_runner can be called "
                    "directly only with a single-step prefill")
        else:
            # We can skip CPU samples for spec token generation.
            # (We do allow CPU samples for num_steps == 1 to support the
            # fallback case, where supports_gpu_multi_step(..) does not pass)
            model_input.sampling_metadata.skip_sampler_cpu_output = (
                not is_fallback)

            # Attn attr defines if we use cuda graphs
            use_cuda_graph = model_input.attn_metadata.use_cuda_graph

        # Get model
        if use_cuda_graph:
            if model_input.inputs_embeds is None:
                graph_batch_size = model_input.input_tokens.shape[0]
                model_executable = (
                    self.graph_runners[model_input.virtual_engine][(
                        graph_batch_size, False)])
            else:
                graph_batch_size = model_input.inputs_embeds.shape[0]
                model_executable = (
                    self.graph_runners[model_input.virtual_engine][(
                        graph_batch_size, True)])

            if previous_hidden_states is not None:
                hidden_states = torch.cat([
                    previous_hidden_states,
                    torch.empty([
                        graph_batch_size - previous_hidden_states.shape[0],
                        *previous_hidden_states.shape[1:]
                    ],
                                dtype=previous_hidden_states.dtype,
                                device=previous_hidden_states.device)
                ])
            else:
                hidden_states = None
        else:
            model_executable = self.model
            hidden_states = previous_hidden_states

        outputs: List[SamplerOutput] = []
        for step in range(num_steps):
            multi_modal_kwargs = model_input.multi_modal_kwargs or {}

            model_execute_kwargs = {"previous_hidden_states": hidden_states} \
                if previous_hidden_states is not None else {}

            compute_logits_kwargs = {}
            # Run model
            if hasattr(self.model.config, "num_nextn_predict_layers"):
                # for DeepSeek MTP only to use the corresponding layer for
                # each step
                spec_step_idx = kwargs.get("spec_step_idx", step)
                model_execute_kwargs["spec_step_idx"] = spec_step_idx
                compute_logits_kwargs["spec_step_idx"] = spec_step_idx
            with set_forward_context(model_input.attn_metadata,
                                     self.vllm_config):
                hidden_states = model_executable(
                    input_ids=model_input.input_tokens,
                    inputs_embeds=None,
                    positions=model_input.input_positions,
                    intermediate_tensors=intermediate_tensors,
                    **MultiModalKwargs.as_kwargs(
                        multi_modal_kwargs,
                        device=self.device,
                    ),
                    **model_execute_kwargs,
                )

            # Compute the logits.
            logits = self.model.compute_logits(hidden_states,
                                               model_input.sampling_metadata,
                                               **compute_logits_kwargs)
            if not self.is_driver_worker:
                return []
            # Sample the next token.
            output = self.model_runner.sampler(
                logits=logits,
                sampling_metadata=model_input.sampling_metadata,
            )
            outputs.append(output)

            if self.return_hidden_states and is_fallback:
                if use_cuda_graph:
                    indices = model_input.sampling_metadata\
                      .selected_token_indices
                    output.hidden_states = hidden_states[:len(indices)]
                else:
                    output.hidden_states = hidden_states

            if model_input.attn_metadata.num_prefills == 0 \
                and self.indices_of_seq_with_bonus_tokens is not None:
                assert output.sampled_token_ids is not None
                # output.sampled_token_ids should be of shape (num_seqs, 1)
                nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
                assert num_tokens_per_seq == 1
                count = 0
                for i in range(nums_seqs):
                    bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
                        count]
                    if i != bonus_seq_idx:
                        # The following might cause a cpu->gpu sync
                        # However, the performance impact is negligible as we
                        # benchmarked on H100.
                        output.sampled_token_ids[
                            i, :] = model_input.input_tokens[bonus_seq_idx]
                    else:
                        count += 1

            # Prepare inputs for the next step
            if step != num_steps - 1:
                model_input = self._gpu_advance_step(model_input, outputs[-1])

        return outputs

indices_of_seq_with_bonus_tokens instance-attribute

indices_of_seq_with_bonus_tokens = None

__init__

__init__(model_runner: ModelRunnerBase)
Source code in vllm/spec_decode/draft_model_runner.py
def __init__(self, model_runner: ModelRunnerBase):
    super().__init__(model_runner)

    self.indices_of_seq_with_bonus_tokens = None

_gpu_advance_step

_gpu_advance_step(
    model_input: ModelRunnerInputBase,
    last_output: SamplerOutput,
) -> ModelRunnerInputBase
Source code in vllm/spec_decode/draft_model_runner.py
def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
                      last_output: SamplerOutput) -> ModelRunnerInputBase:
    # Currently, we expect "decode mode" only
    assert not model_input.is_prompt

    # Get num_seqs
    num_seqs = len(model_input.seq_lens)
    num_queries = len(model_input.query_lens)

    # Get output tokens GPU tensor
    sampled_token_ids = last_output.sampled_token_ids
    assert sampled_token_ids is not None

    # Update attn_metadata
    attn_metadata = model_input.attn_metadata
    assert isinstance(attn_metadata, FlashAttentionMetadata)

    attn_metadata.advance_step(model_input, sampled_token_ids,
                               self.block_size, num_seqs, num_queries)

    # Update sampling_metadata
    sampling_metadata = model_input.sampling_metadata
    self._update_sampling_metadata(sampling_metadata, num_seqs,
                                   num_queries)

    # Create new input
    new_model_input = self._model_input_cls(
        input_tokens=model_input.input_tokens,
        input_positions=model_input.input_positions,
        attn_metadata=attn_metadata,
        seq_lens=attn_metadata.seq_lens,
        query_lens=model_input.query_lens,
        lora_mapping=model_input.lora_mapping,
        lora_requests=model_input.lora_requests,
        multi_modal_kwargs=model_input.multi_modal_kwargs,
        sampling_metadata=model_input.sampling_metadata,
        is_prompt=False,
    )

    # Ensure we skip CPU samples
    assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
    # We can reuse sampling tensors since every decode iteration is the same
    new_model_input.sampling_metadata.reuse_sampling_tensors = True

    if debug_advance_input:
        logger.debug("NEW INPUT: ")
        logger.debug("  input_tokens = %s", new_model_input.input_tokens)
        logger.debug("  input_positions = %s",
                     new_model_input.input_positions)
        logger.debug("  seq_lens = %d", new_model_input.seq_lens)
        logger.debug("  query_lens = %d", new_model_input.query_lens)
        logger.debug("  attn_metadata:")
        logger.debug("    seq_lens_tensor: %s",
                     attn_metadata.seq_lens_tensor)
        logger.debug("    slot_mapping: %s", attn_metadata.slot_mapping)
        logger.debug("    block_tables: %s", attn_metadata.block_tables)

    return new_model_input

_update_sampling_metadata

_update_sampling_metadata(
    sampling_metadata, num_seqs, num_queries
)
Source code in vllm/spec_decode/draft_model_runner.py
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                              num_queries):

    assert sampling_metadata.num_prompts == 0
    assert len(sampling_metadata.seq_groups) == num_queries
    assert sampling_metadata.selected_token_indices.shape == (
        num_queries, )
    # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501

    # Verify that all sequences are decodes
    for i in range(num_queries):
        seq_group = sampling_metadata.seq_groups[i]

        assert seq_group.is_prompt is False  # No prompt
        assert seq_group.prompt_logprob_indices == []  # No prompt
        assert seq_group.sample_indices == [i]  # Simple

execute_model

execute_model(
    model_input: ModelRunnerInputBase,
    kv_caches: List[Tensor],
    previous_hidden_states: Optional[Tensor] = None,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    num_steps: int = 1,
    **kwargs,
) -> Optional[List[SamplerOutput]]

Executes num_steps forward passes with advancement of input tensors on the GPU. See supports_gpu_multi_step(..) for pre-conditions.

Optimizations used
  1. Input tensors are updated on the GPU directly
  2. Skips GPU=>CPU serialization of sampler outputs (we don't need them since we do batch expansion later that uses GPU outputs)
  3. Reuses sampling tensors (since we run only decodes and they have a repeating sampling logic)
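
Optimization 2 can be seen in isolation with a small, self-contained comparison (a hedged sketch, not vLLM code): serializing the sampler output to the CPU at every step forces a device synchronization per step, whereas keeping the sampled ids on the device and transferring once at the end does not.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(8, 32000, device=device)

def per_step_sync(num_steps: int) -> list:
    # GPU => CPU copy (and implicit synchronization) on every step.
    return [logits.argmax(dim=-1).tolist() for _ in range(num_steps)]

def single_sync(num_steps: int) -> torch.Tensor:
    # Sampled ids stay on the device; a single transfer at the very end.
    sampled = [logits.argmax(dim=-1) for _ in range(num_steps)]
    return torch.stack(sampled).cpu()

print(single_sync(4).shape)  # torch.Size([4, 8])
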
Source code in vllm/spec_decode/draft_model_runner.py
@torch.inference_mode()
def execute_model(
    self,
    model_input: ModelRunnerInputBase,
    kv_caches: List[torch.Tensor],
    previous_hidden_states: Optional[torch.Tensor] = None,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    num_steps: int = 1,
    **kwargs,
) -> Optional[List[SamplerOutput]]:
    """Executes num_steps forward passes with advacement of input tensors
    on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.

    Optimizations used:
        1. Input tensors are updated on the GPU directly
        2. Skips GPU=>CPU serialization of sampler outputs (we don't need
            them since we do batch expansion later that uses GPU outputs)
        3. Reuses sampling tensors (since we run only decodes and they have
            a repeating sampling logic)
    """

    # When num_steps == 1, we execute the fallback here for the GPU
    # advance_step, which runs prepare_inputs on CPU and for each spec
    # iteration invokes this function only once
    # (Look at multi-step-worker code)
    is_fallback = num_steps == 1
    if not is_fallback:
        # Since we do not broadcast data inside execute_model anymore,
        # we need to figure out the best way to support TP > 1 in this
        # case, because we will at least need to broadcast the sampled
        # tokens to all workers.
        if not self.is_driver_worker:
            raise ValueError("TP1DraftModelRunner only supports TP=1.")

        # Sanity
        if self.lora_config is not None:
            raise ValueError("TP1DraftModelRunner has no support for LORA")
        if self.prompt_adapter_config is not None:
            raise ValueError("TP1DraftModelRunner has no support for "
                             "prompt_adapter_config")
        if model_input.inputs_embeds is not None:
            raise ValueError("TP1DraftModelRunner has no support for "
                             "inputs_embeds")
        if model_input.multi_modal_kwargs:
            raise ValueError(
                "TP1DraftModelRunner has no support for multi_modal_kwargs"
            )
    else:
        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        self.attn_state.begin_forward(model_input)

    # Detect exec mode
    assert model_input.attn_metadata is not None
    use_cuda_graph = False
    if model_input.attn_metadata.num_prefills > 0:
        # In this case, execute_model(..) was called directly
        if num_steps > 1:
            raise ValueError(
                "execute_model(..) of draft_model_runner can be called "
                "directly only with a single-step prefill")
    else:
        # We can skip CPU samples for spec token generation.
        # (We do allow CPU samples for num_steps == 1 to support the
        # fallback case, where supports_gpu_multi_step(..) does not pass)
        model_input.sampling_metadata.skip_sampler_cpu_output = (
            not is_fallback)

        # Attn attr defines if we use cuda graphs
        use_cuda_graph = model_input.attn_metadata.use_cuda_graph

    # Get model
    if use_cuda_graph:
        if model_input.inputs_embeds is None:
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = (
                self.graph_runners[model_input.virtual_engine][(
                    graph_batch_size, False)])
        else:
            graph_batch_size = model_input.inputs_embeds.shape[0]
            model_executable = (
                self.graph_runners[model_input.virtual_engine][(
                    graph_batch_size, True)])

        if previous_hidden_states is not None:
            hidden_states = torch.cat([
                previous_hidden_states,
                torch.empty([
                    graph_batch_size - previous_hidden_states.shape[0],
                    *previous_hidden_states.shape[1:]
                ],
                            dtype=previous_hidden_states.dtype,
                            device=previous_hidden_states.device)
            ])
        else:
            hidden_states = None
    else:
        model_executable = self.model
        hidden_states = previous_hidden_states

    outputs: List[SamplerOutput] = []
    for step in range(num_steps):
        multi_modal_kwargs = model_input.multi_modal_kwargs or {}

        model_execute_kwargs = {"previous_hidden_states": hidden_states} \
            if previous_hidden_states is not None else {}

        compute_logits_kwargs = {}
        # Run model
        if hasattr(self.model.config, "num_nextn_predict_layers"):
            # for DeepSeek MTP only to use the corresponding layer for
            # each step
            spec_step_idx = kwargs.get("spec_step_idx", step)
            model_execute_kwargs["spec_step_idx"] = spec_step_idx
            compute_logits_kwargs["spec_step_idx"] = spec_step_idx
        with set_forward_context(model_input.attn_metadata,
                                 self.vllm_config):
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
                inputs_embeds=None,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(
                    multi_modal_kwargs,
                    device=self.device,
                ),
                **model_execute_kwargs,
            )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata,
                                           **compute_logits_kwargs)
        if not self.is_driver_worker:
            return []
        # Sample the next token.
        output = self.model_runner.sampler(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        outputs.append(output)

        if self.return_hidden_states and is_fallback:
            if use_cuda_graph:
                indices = model_input.sampling_metadata\
                  .selected_token_indices
                output.hidden_states = hidden_states[:len(indices)]
            else:
                output.hidden_states = hidden_states

        if model_input.attn_metadata.num_prefills == 0 \
            and self.indices_of_seq_with_bonus_tokens is not None:
            assert output.sampled_token_ids is not None
            # output.sampled_token_ids should be of shape (num_seqs, 1)
            nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
            assert num_tokens_per_seq == 1
            count = 0
            for i in range(nums_seqs):
                bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
                    count]
                if i != bonus_seq_idx:
                    # The following might cause a cpu->gpu sync
                    # However, the performance impact is negligible as we
                    # benchmarked on H100.
                    output.sampled_token_ids[
                        i, :] = model_input.input_tokens[bonus_seq_idx]
                else:
                    count += 1

        # Prepare inputs for the next step
        if step != num_steps - 1:
            model_input = self._gpu_advance_step(model_input, outputs[-1])

    return outputs

set_indices_of_seq_with_bonus_tokens

set_indices_of_seq_with_bonus_tokens(
    indices_of_seq_with_bonus_tokens,
)
Source code in vllm/spec_decode/draft_model_runner.py
def set_indices_of_seq_with_bonus_tokens(self,
                                         indices_of_seq_with_bonus_tokens):
    self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens

supports_gpu_multi_step

supports_gpu_multi_step(
    execute_model_req: ExecuteModelRequest,
)

Determines if draft_model_runner GPU multi-step can be used. Currently required conditions are:
  1. Only decodes
  2. Only flash-attn
  3. No LORA
  4. No prompt_adapter_config
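
The same gating can be exercised standalone; the FakeRunner and FakeSeqGroup classes below are illustrative stand-ins for vLLM objects, useful only for seeing which request shapes fall back to single-step execution.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakeSeqGroup:
    is_prompt: bool

@dataclass
class FakeRunner:
    attn_backend_name: str = "FLASH_ATTN"
    lora_config: Optional[object] = None
    prompt_adapter_config: Optional[object] = None

def can_use_gpu_multi_step(runner: FakeRunner,
                           seq_groups: List[FakeSeqGroup],
                           allow_gpu_advance_step: bool = True) -> bool:
    if not allow_gpu_advance_step:
        return False
    # 1. Only decodes: a single prompt (prefill) sequence disables the fast path.
    if any(sg.is_prompt for sg in seq_groups):
        return False
    # 2. Only flash-attn (advance_step currently supports only this backend).
    if runner.attn_backend_name not in ("FLASH_ATTN",):
        return False
    # 3. No LoRA and 4. no prompt adapters.
    if runner.lora_config:
        return False
    return not runner.prompt_adapter_config

print(can_use_gpu_multi_step(FakeRunner(), [FakeSeqGroup(False)]))            # True
print(can_use_gpu_multi_step(FakeRunner(), [FakeSeqGroup(True)]))             # False: prefill present
print(can_use_gpu_multi_step(FakeRunner("XFORMERS"), [FakeSeqGroup(False)]))  # False: backend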

Source code in vllm/spec_decode/draft_model_runner.py
def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
    """Determines if draft_model_runner GPU multi-step can be used.
    Currently required conditions are:
        1. Only decodes
        2. Only flash-attn
        3. No LORA
        4. No prompt_adapter_config
    """
    if not allow_gpu_advance_step:
        return False

    # We allow multi-step GPU only in decode mode
    for seq_group in execute_model_req.seq_group_metadata_list:
        if seq_group.is_prompt:
            return False

    # TODO: Add support for other attn backends
    if self.attn_backend.get_name() not in ("FLASH_ATTN", ):
        return False

    # TODO: Add support for LORA
    if self.lora_config:
        return False

    # TODO: Add soft-tuning prompt adapter support
    return not self.prompt_adapter_config