vllm.worker.multi_step_neuronx_distributed_model_runner

MultiStepNeuronxDistributedModelRunner

Bases: NeuronxDistributedModelRunner

A model runner for multi-step decoding using the neuronx-distributed-inference framework.

Source code in vllm/worker/multi_step_neuronx_distributed_model_runner.py
class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner):
    """A model runner for multi-step decoding using the
    neuronx-distributed-inference framework"""

    def __init__(
        self,
        vllm_config: VllmConfig,
    ):
        super().__init__(vllm_config)

    def load_model(self) -> None:
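        # Import lazily so the neuronx-distributed-inference loader is only
        # pulled in when a Neuron speculation model is actually loaded.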
        from vllm.model_executor.model_loader.neuronx_distributed import (
            get_neuron_speculation_model)
        self.model = get_neuron_speculation_model(
            self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            speculation_config=self.speculative_config)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input,
        kv_caches: Optional[List[torch.Tensor]] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
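        # Pack each sequence group's (top_k, top_p, temperature) into one
        # row of a tensor so the Neuron model can apply sampling on-device.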
        sampling_params = torch.tensor([[
            seq_group.sampling_params.top_k,
            seq_group.sampling_params.top_p,
            seq_group.sampling_params.temperature,
        ] for seq_group in model_input.sampling_metadata.seq_groups])

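        # Forward pass through the compiled Neuron model; multi-modal
        # inputs, if present, are expanded into keyword arguments.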
        logits = self.model(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            input_block_ids=model_input.input_block_ids,
            sampling_params=sampling_params,
            **MultiModalKwargs.as_kwargs(
                model_input.multi_modal_kwargs or {},
                device=self.device,
            ),
        )

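        # Sample output tokens from the logits using the per-request
        # sampling metadata.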
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return output
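
A minimal usage sketch: in practice the engine's Neuron worker constructs and
drives this runner, so the `vllm_config` and `model_input` objects below are
placeholders for state built by vLLM's engine plumbing, not real setup code.

runner = MultiStepNeuronxDistributedModelRunner(vllm_config)
runner.load_model()                          # compile/load the Neuron speculation model
outputs = runner.execute_model(model_input)  # one multi-step decode call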

__init__

__init__(vllm_config: VllmConfig)
Source code in vllm/worker/multi_step_neuronx_distributed_model_runner.py
def __init__(
    self,
    vllm_config: VllmConfig,
):
    super().__init__(vllm_config)

execute_model

execute_model(
    model_input,
    kv_caches: Optional[List[Tensor]] = None,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    num_steps: int = 1,
) -> Optional[List[SamplerOutput]]
Source code in vllm/worker/multi_step_neuronx_distributed_model_runner.py
@torch.inference_mode()
def execute_model(
    self,
    model_input,
    kv_caches: Optional[List[torch.Tensor]] = None,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
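    # Pack each sequence group's (top_k, top_p, temperature) into one
    # row of a tensor so the Neuron model can apply sampling on-device.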
    sampling_params = torch.tensor([[
        seq_group.sampling_params.top_k,
        seq_group.sampling_params.top_p,
        seq_group.sampling_params.temperature,
    ] for seq_group in model_input.sampling_metadata.seq_groups])

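    # Forward pass through the compiled Neuron model; multi-modal
    # inputs, if present, are expanded into keyword arguments.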
    logits = self.model(
        input_ids=model_input.input_tokens,
        positions=model_input.input_positions,
        input_block_ids=model_input.input_block_ids,
        sampling_params=sampling_params,
        **MultiModalKwargs.as_kwargs(
            model_input.multi_modal_kwargs or {},
            device=self.device,
        ),
    )

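    # Sample output tokens from the logits using the per-request
    # sampling metadata.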
    output = self.model.sample(
        logits=logits,
        sampling_metadata=model_input.sampling_metadata,
    )
    return output
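
To make the packing step concrete, here is a self-contained sketch of the
tensor `execute_model` builds: one row per sequence group, holding that
group's (top_k, top_p, temperature). The three groups and their values are
hypothetical.

import torch

sampling_params = torch.tensor([
    [50.0, 0.90, 0.7],
    [-1.0, 1.00, 1.0],  # top_k = -1 is vLLM's "top-k disabled" convention
    [40.0, 0.95, 0.8],
])
print(sampling_params.shape)  # torch.Size([3, 3]): one row per sequence group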

load_model

load_model() -> None
Source code in vllm/worker/multi_step_neuronx_distributed_model_runner.py
def load_model(self) -> None:
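    # Import lazily so the neuronx-distributed-inference loader is only
    # pulled in when a Neuron speculation model is actually loaded.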
    from vllm.model_executor.model_loader.neuronx_distributed import (
        get_neuron_speculation_model)
    self.model = get_neuron_speculation_model(
        self.model_config,
        parallel_config=self.parallel_config,
        scheduler_config=self.scheduler_config,
        speculation_config=self.speculative_config)
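
This runner only comes into play when the engine is configured for
speculative decoding on Neuron. As a rough, hypothetical sketch of enabling
that from the user-facing API (the argument name and config keys below are
assumptions that vary across vLLM versions, not part of this module's API):

from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",         # hypothetical target model
    speculative_config={
        "model": "meta-llama/Llama-3.2-1B-Instruct",  # hypothetical draft model
        "num_speculative_tokens": 4,
    },
)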