Skip to content

vllm.spec_decode.mlp_speculator_worker

MLPSpeculatorWorker

Bases: NonLLMProposerWorkerBase, MultiStepWorker

Worker for MLPSpeculator models.

Not currently compatible with LoRA or chunked prefill.

Source code in vllm/spec_decode/mlp_speculator_worker.py
class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
    """Worker for MLPSpeculator models.

    Not currently compatible with LoRA or chunked prefill.
    """

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        # Unused parameter. MLPSpeculatorWorker does not use the KV Cache and
        # therefore does not need this parameter.
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass to generate ``sample_len`` future tokens.

        Args:
            execute_model_req: The batch to propose for. Its
                ``seq_group_metadata_list`` provides the input tokens, and
                ``previous_hidden_states.hidden_states`` is forwarded to the
                model's ``generate_proposals`` as ``previous_hidden_states``.
            sample_len: Number of tokens to propose. The model must return
                exactly this many outputs.
            seq_ids_with_bonus_token_in_last_step: Ignored (see the comment
                on the parameter).

        Returns:
            The list of sampler outputs, one per proposed token, along with
            an indicator of whether the torch tensors in the sampler output
            need to be transposed by the later ``sampler_output_to_torch``
            logic. For the MLP speculator worker this indicator is always
            ``True``.
        """
        self._raise_if_unsupported(execute_model_req)

        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        (input_tokens, seq_lens,
         query_lens) = self._prepare_input_tensors(seq_group_metadata_list)

        generators = self.model_runner.get_generators(
            execute_model_req.finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, seq_lens, query_lens, self.device,
            self.model_runner.pin_memory, generators)

        model_outputs = self.model_runner.model.generate_proposals(
            input_ids=input_tokens,
            previous_hidden_states=execute_model_req.previous_hidden_states.
            hidden_states,
            num_predict_tokens=sample_len,
            sampling_metadata=sampling_metadata)

        # Internal invariant: one output per requested speculative token.
        assert len(model_outputs) == sample_len

        return model_outputs, True

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, List[int], List[int]]:
        """Flatten a batch of sequence groups into speculator inputs.

        For prompt groups, the tokens from the number of already-computed
        tokens up to ``token_chunk_size`` further (capped at the sequence
        length) are emitted. For decode groups, only the last token id of
        each sequence is emitted.

        Args:
            seq_group_metadata_list: Sequence groups in the batch. ``None``
                or an empty list yields empty outputs.

        Returns:
            ``(input_tokens, seq_lens, query_lens)``:

            - ``input_tokens`` is a 1-D ``torch.long`` tensor of token ids
              on ``self.device``.
            - ``seq_lens`` holds, per sequence, the length through the last
              token being processed.
            - ``query_lens`` holds, per sequence, how many tokens it
              contributed to ``input_tokens``.
        """
        if not seq_group_metadata_list:
            # Match the non-empty path's dtype: token ids are integers, but
            # torch.empty would otherwise use the default floating dtype.
            empty_tokens = torch.empty(0, dtype=torch.long, device=self.device)
            return empty_tokens, [], []

        input_tokens: List[int] = []
        seq_lens: List[int] = []
        query_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt

            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
                if is_prompt:
                    # Process the next chunk of not-yet-computed prompt tokens.
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    tokens = seq_data.get_token_ids()[context_len:seq_len]
                    seq_lens.append(seq_len)
                    input_tokens.extend(tokens)
                    query_lens.append(seq_len - context_len)
                else:
                    # Decode step: only the most recent token is fed in.
                    seq_lens.append(seq_data_len)
                    input_tokens.append(seq_data.get_last_token_id())
                    query_lens.append(1)

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=self.device)
        return input_tokens_tensor, seq_lens, query_lens

_prepare_input_tensors

_prepare_input_tensors(
    seq_group_metadata_list: Optional[
        List[SequenceGroupMetadata]
    ],
) -> Tuple[Tensor, List[int], List[int]]
Source code in vllm/spec_decode/mlp_speculator_worker.py
def _prepare_input_tensors(
    self,
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, List[int], List[int]]:
    if not seq_group_metadata_list:
        return torch.empty(0, device=self.device), [], []

    input_tokens: List[int] = []
    seq_lens: List[int] = []
    query_lens: List[int] = []

    for seq_group_metadata in seq_group_metadata_list:
        is_prompt = seq_group_metadata.is_prompt

        for seq_data in seq_group_metadata.seq_data.values():
            seq_data_len = seq_data.get_len()
            if is_prompt:
                context_len = seq_data.get_num_computed_tokens()
                seq_len = min(
                    seq_data_len,
                    context_len + seq_group_metadata.token_chunk_size)
                tokens = seq_data.get_token_ids()[context_len:seq_len]
                seq_lens.append(seq_len)
                input_tokens.extend(tokens)
                query_lens.append(seq_len - context_len)
            else:
                seq_lens.append(seq_data_len)
                input_tokens.append(seq_data.get_last_token_id())
                query_lens.append(1)

    input_tokens_tensor = torch.tensor(input_tokens,
                                       dtype=torch.long,
                                       device=self.device)
    return input_tokens_tensor, seq_lens, query_lens

sampler_output

sampler_output(
    execute_model_req: ExecuteModelRequest,
    sample_len: int,
    seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]

Run the model forward pass to generate `sample_len` future tokens. Returns the list of sampler outputs, one per layer, along with an indicator of whether the torch tensors in the sampler output need to be transposed by the later `sampler_output_to_torch` logic.

For the MLP speculator worker, this indicator is always True.

Source code in vllm/spec_decode/mlp_speculator_worker.py
@torch.inference_mode()
def sampler_output(
    self,
    execute_model_req: ExecuteModelRequest,
    sample_len: int,
    # Not used: the MLP speculator never reads the KV cache, so it has no
    # need to know which sequences received a bonus token last step.
    seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
    """Propose ``sample_len`` speculative tokens with the MLP speculator.

    The request is first checked with ``self._raise_if_unsupported``. The
    batch's token inputs are then built, sampling metadata is prepared, and
    the model's ``generate_proposals`` is called with the request's previous
    hidden states.

    Args:
        execute_model_req: Batch to propose for.
        sample_len: How many tokens to propose. The model's output list is
            asserted to have exactly this length.
        seq_ids_with_bonus_token_in_last_step: Ignored.

    Returns:
        ``(outputs, transposed)``: the sampler outputs, one per proposed
        token, and the constant ``True``. The flag tells the later
        ``sampler_output_to_torch`` logic that the tensors in the sampler
        output need to be transposed.
    """
    self._raise_if_unsupported(execute_model_req)

    metadata_list = execute_model_req.seq_group_metadata_list
    token_ids, lens_per_seq, lens_per_query = self._prepare_input_tensors(
        metadata_list)

    runner = self.model_runner
    rng_by_request = runner.get_generators(
        execute_model_req.finished_requests_ids)
    sampling_meta = SamplingMetadata.prepare(metadata_list, lens_per_seq,
                                             lens_per_query, self.device,
                                             runner.pin_memory,
                                             rng_by_request)

    propose = runner.model.generate_proposals
    prior_hidden = execute_model_req.previous_hidden_states.hidden_states
    proposals = propose(input_ids=token_ids,
                        previous_hidden_states=prior_hidden,
                        num_predict_tokens=sample_len,
                        sampling_metadata=sampling_meta)

    assert len(proposals) == sample_len
    return proposals, True