Skip to content

vllm.worker.multi_step_worker

MultiStepState dataclass

Source code in vllm/worker/multi_step_worker.py
@dataclass
class MultiStepState:
    """Inputs kept by the driver for reuse across the steps of a multi-step run.

    An instance is stored per virtual engine when the first step is prepared
    and is read back on the following steps instead of preparing the inputs
    again.
    """

    # Worker-side input built by ``prepare_worker_input`` on the first step.
    worker_input: WorkerInput
    # Model-runner input built by ``prepare_model_input`` on the first step;
    # mutated in place on later steps.
    model_input: StatefulModelInput

model_input instance-attribute

model_input: StatefulModelInput

worker_input instance-attribute

worker_input: WorkerInput

__init__

__init__(
    worker_input: WorkerInput,
    model_input: StatefulModelInput,
) -> None

MultiStepWorker

Bases: Worker

Source code in vllm/worker/multi_step_worker.py
class MultiStepWorker(Worker):
    """Worker that reuses prepared inputs across the steps of a multi-step run.

    The base model runner is wrapped in a ``MultiStepModelRunner``. On the
    driver, the worker/model inputs prepared on the first step are cached per
    virtual engine and reused (and updated in place) on subsequent steps.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        base_model_runner = self.model_runner
        # for multi-step model, wrap the model runner with MultiStepModelRunner
        self.model_runner = MultiStepModelRunner(
            base_model_runner,
            vllm_config=base_model_runner.vllm_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=base_model_runner.is_driver_worker,
        )

        # One cached-state slot per virtual engine (indexed by
        # ``virtual_engine``), sized by the pipeline-parallel degree.
        pipeline_parallel_size = self.parallel_config.pipeline_parallel_size
        self.multi_step_states: List[
            Optional[MultiStepState]] = [None] * pipeline_parallel_size
        self.temp_output = None

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
        """
        Get the driver input and broadcast it to other workers.

        On the first step of a multi-step run the worker and model inputs are
        prepared from scratch. On later steps the inputs cached in
        ``self.multi_step_states[virtual_engine]`` are reused and updated with
        the last sampled token ids.

        Args:
            execute_model_req: the request for the current step.

        Returns:
            ``(model_input, worker_input, {})``. The empty dict keeps the
            return shape compatible with
            ``LocalOrDistributedWorkerBase._get_driver_input_and_broadcast``.

        Raises:
            AssertionError: (when assertions are enabled) if a non-first step
                arrives for a virtual engine with no cached state.
        """
        assert self.is_driver_worker
        virtual_engine = execute_model_req.virtual_engine
        is_first_multi_step = execute_model_req.is_first_multi_step
        if is_first_multi_step:
            # on first step we prepare the worker input and model input normally
            worker_input: WorkerInput = self.prepare_worker_input(
                execute_model_req=execute_model_req)
            model_input: StatefulModelInput = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids))

            if execute_model_req.async_callback:
                model_input.frozen_model_input = dataclasses.replace(  # type: ignore
                    model_input.frozen_model_input,
                    async_callback=execute_model_req.async_callback)
        else:
            # on subsequent steps we reuse the worker input and model input
            multi_step_state = self.multi_step_states[virtual_engine]
            # The slot is only filled by ``prepare_input`` after a first step;
            # fail with a clear message instead of an AttributeError on None.
            assert multi_step_state is not None, (
                "No cached multi-step state for virtual engine "
                f"{virtual_engine}; the first step was never prepared.")
            worker_input = multi_step_state.worker_input
            model_input = multi_step_state.model_input
            frozen_model_input = model_input.frozen_model_input
            assert frozen_model_input is not None
            assert frozen_model_input.attn_metadata is not None
            # clear the cached metadata so that it can be recomputed on
            # the workers.
            frozen_model_input.attn_metadata._cached_prefill_metadata = None
            frozen_model_input.attn_metadata._cached_decode_metadata = None

        model_input.is_first_multi_step = is_first_multi_step
        model_input.is_last_step = execute_model_req.is_last_step

        if not is_first_multi_step:
            # we broadcast the last sampled token ids to all TP workers so they
            # can update their model input metadata in-place.
            self._prepare_last_sampled_token_ids_for_tp_workers(
                execute_model_req=execute_model_req, model_input=model_input)

        if self.do_metadata_broadcast:
            broadcast_data = worker_input.as_broadcastable_tensor_dict()
            broadcast_data.update(model_input.as_broadcastable_tensor_dict())
            broadcast_tensor_dict(broadcast_data, src=0)

        # Returning empty dict here to keep this compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
        return model_input, worker_input, {}

    def _prepare_last_sampled_token_ids_for_tp_workers(
        self,
        execute_model_req: ExecuteModelRequest,
        model_input: StatefulModelInput,
    ) -> None:
        """
        Prepare the last sampled token ids for TP workers.

        If this is the last PP rank, the last sampled token ids are already in
        ``model_input.cached_outputs``. Otherwise they are taken from the copy
        cached in ``execute_model_req`` and moved to the GPU.
        """
        if get_pp_group().is_last_rank:
            assert model_input.cached_outputs[
                -1].sampler_output.sampled_token_ids is None
            assert model_input.cached_outputs[-1].sampled_token_ids is not None
            model_input.last_sampled_token_ids = model_input.cached_outputs[
                -1].sampled_token_ids
            # free sampled token ids from the previous step if it has been
            # pythonized. Cannot free the last sampled token ids because
            # we need it for GPU advance_step.
            for output in model_input.cached_outputs[:-1]:
                if output.pythonized:
                    output.sampled_token_ids = None
        else:
            # otherwise we need to get the cached sampled token ids from the
            # execute_model_req
            assert execute_model_req.last_sampled_token_ids is not None
            model_input.last_sampled_token_ids = (
                execute_model_req.last_sampled_token_ids.cuda())
            model_input.add_sampler_output(
                SamplerOutput(outputs=[], sampled_token_ids=None),
                model_input.last_sampled_token_ids)

            # free sampled token ids from the previous step.
            # TODO(will) we could reuse the sampled token ids tensor from
            # the previous step instead.
            for output in model_input.cached_outputs[:-1]:
                output.sampled_token_ids = None
            assert model_input.cached_outputs[-1].sampled_token_ids is not None

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str,
                                                              torch.Tensor]]]:
        """
        Prepare the inputs for one step of a multi-step run.

        Depending on the current state of the request and multi step worker,
        this method may skip the normal ``prepare_worker_input`` /
        ``prepare_model_input`` calls and instead reuse cached values.

        On the driver, a ``None`` request means there is no work. In that case
        an empty dict is broadcast (when metadata broadcast is enabled) to stop
        the other workers' loops, and ``None`` is returned. Non-driver workers
        receive their inputs via broadcast and return ``None`` when the driver
        sent that stop signal.

        Returns:
            ``(model_input, worker_input, kwargs)``, or ``None`` when there is
            nothing to execute.
        """
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # This signals that there's no more requests to process for
                    # now. All workers are running infinite loop with
                    # broadcast_tensor_dict, and it stops the loop when the
                    # driver broadcasts an empty input. Send an empty input to
                    # notify all other workers to stop their execution loop.
                    broadcast_tensor_dict({}, src=0)
                return None

            virtual_engine = execute_model_req.virtual_engine
            (model_input, worker_input,
             kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
            assert isinstance(model_input, StatefulModelInput)
            if execute_model_req.is_first_multi_step:
                # cache the worker input and model input for the next steps
                self.multi_step_states[virtual_engine] = MultiStepState(
                    worker_input=worker_input, model_input=model_input)
        # non-driver (e.g. TP) workers
        else:
            broadcast_data = self._get_worker_input_from_broadcast()
            # if the driver has sent an empty input, we should stop the worker
            # loop
            if broadcast_data is None:
                return None
            model_input, worker_input, kwargs = broadcast_data
            assert isinstance(model_input, StatefulModelInput)
            # TODO(will) The worker could cache the worker input and model
            # input on the first step and reuse them afterwards. The idea is
            # essentially the delta optimization for model_inputs, where only
            # the delta needed for the next step (sampled_token_ids from the
            # previous step) is broadcast.
            if not model_input.is_first_multi_step:
                # we need to update the last sampled token ids in the model
                # input for the workers so that they can run inplace
                # advance_step
                model_input.add_sampler_output(
                    SamplerOutput(outputs=[], sampled_token_ids=None),
                    model_input.last_sampled_token_ids)

        assert model_input is not None
        assert worker_input is not None
        return model_input, worker_input, kwargs

model_runner instance-attribute

model_runner = MultiStepModelRunner(
    base_model_runner,
    vllm_config=vllm_config,
    kv_cache_dtype=cache_dtype,
    is_driver_worker=is_driver_worker,
)

multi_step_states instance-attribute

multi_step_states: List[Optional[MultiStepState]] = [
    None
] * pipeline_parallel_size

temp_output instance-attribute

temp_output = None

__init__

__init__(*args, **kwargs)
Source code in vllm/worker/multi_step_worker.py
def __init__(self, *args, **kwargs):
    """Build the base worker, then wrap its runner for multi-step execution.

    All positional and keyword arguments are forwarded unchanged to the
    parent ``__init__``.
    """
    super().__init__(*args, **kwargs)
    # The runner created by the parent becomes the wrapped runner; its config
    # and driver flag are passed through to the multi-step wrapper.
    inner_runner = self.model_runner
    self.model_runner = MultiStepModelRunner(
        inner_runner,
        vllm_config=inner_runner.vllm_config,
        kv_cache_dtype=self.cache_config.cache_dtype,
        is_driver_worker=inner_runner.is_driver_worker,
    )

    # One (initially empty) cached-state slot per virtual engine.
    num_slots = self.parallel_config.pipeline_parallel_size
    self.multi_step_states: List[Optional[MultiStepState]] = [
        None for _ in range(num_slots)
    ]
    self.temp_output = None

_get_driver_input_and_broadcast

_get_driver_input_and_broadcast(
    execute_model_req: ExecuteModelRequest,
) -> Tuple[
    BroadcastableModelInput, WorkerInput, Dict[str, Tensor]
]

Get the driver input and broadcast it to other workers.

Source code in vllm/worker/multi_step_worker.py
def _get_driver_input_and_broadcast(
    self, execute_model_req: ExecuteModelRequest
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
    """
    Get the driver input and broadcast it to other workers.

    On the first step of a multi-step run the worker and model inputs are
    prepared from scratch. On later steps the inputs cached in
    ``self.multi_step_states[virtual_engine]`` are reused and updated with the
    last sampled token ids.

    Args:
        execute_model_req: the request for the current step.

    Returns:
        ``(model_input, worker_input, {})``. The empty dict keeps the return
        shape compatible with
        ``LocalOrDistributedWorkerBase._get_driver_input_and_broadcast``.

    Raises:
        AssertionError: (when assertions are enabled) if a non-first step
            arrives for a virtual engine with no cached state.
    """
    assert self.is_driver_worker
    virtual_engine = execute_model_req.virtual_engine
    is_first_multi_step = execute_model_req.is_first_multi_step
    if is_first_multi_step:
        # on first step we prepare the worker input and model input normally
        worker_input: WorkerInput = self.prepare_worker_input(
            execute_model_req=execute_model_req)
        model_input: StatefulModelInput = (
            self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.virtual_engine,
                execute_model_req.finished_requests_ids))

        if execute_model_req.async_callback:
            model_input.frozen_model_input = dataclasses.replace(  # type: ignore
                model_input.frozen_model_input,
                async_callback=execute_model_req.async_callback)
    else:
        # on subsequent steps we reuse the worker input and model input
        multi_step_state = self.multi_step_states[virtual_engine]
        # The slot is only filled by ``prepare_input`` after a first step;
        # fail with a clear message instead of an AttributeError on None.
        assert multi_step_state is not None, (
            "No cached multi-step state for virtual engine "
            f"{virtual_engine}; the first step was never prepared.")
        worker_input = multi_step_state.worker_input
        model_input = multi_step_state.model_input
        frozen_model_input = model_input.frozen_model_input
        assert frozen_model_input is not None
        assert frozen_model_input.attn_metadata is not None
        # clear the cached metadata so that it can be recomputed on
        # the workers.
        frozen_model_input.attn_metadata._cached_prefill_metadata = None
        frozen_model_input.attn_metadata._cached_decode_metadata = None

    model_input.is_first_multi_step = is_first_multi_step
    model_input.is_last_step = execute_model_req.is_last_step

    if not is_first_multi_step:
        # we broadcast the last sampled token ids to all TP workers so they
        # can update their model input metadata in-place.
        self._prepare_last_sampled_token_ids_for_tp_workers(
            execute_model_req=execute_model_req, model_input=model_input)

    if self.do_metadata_broadcast:
        broadcast_data = worker_input.as_broadcastable_tensor_dict()
        broadcast_data.update(model_input.as_broadcastable_tensor_dict())
        broadcast_tensor_dict(broadcast_data, src=0)

    # Returning empty dict here to keep this compatible with
    # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
    return model_input, worker_input, {}

_prepare_last_sampled_token_ids_for_tp_workers

_prepare_last_sampled_token_ids_for_tp_workers(
    execute_model_req: ExecuteModelRequest,
    model_input: StatefulModelInput,
) -> None

Prepare the last sampled token ids for TP workers. If this is the last PP rank, the last sampled token ids are already in the model_input. Otherwise, they must be taken from the copy cached in the execute_model_req.

Source code in vllm/worker/multi_step_worker.py
def _prepare_last_sampled_token_ids_for_tp_workers(
    self,
    execute_model_req: ExecuteModelRequest,
    model_input: StatefulModelInput,
) -> None:
    """Set ``model_input.last_sampled_token_ids`` for a non-first step.

    On a PP rank other than the last, the ids come from
    ``execute_model_req.last_sampled_token_ids``. They are moved with
    ``.cuda()`` and appended as a placeholder sampler output, and every older
    cached output drops its token ids. On the last PP rank, the ids are taken
    from the newest cached output, and older outputs drop theirs only once
    they have been pythonized.
    """
    if not get_pp_group().is_last_rank:
        # Token ids for this rank arrive through the request.
        incoming_ids = execute_model_req.last_sampled_token_ids
        assert incoming_ids is not None
        model_input.last_sampled_token_ids = incoming_ids.cuda()
        model_input.add_sampler_output(
            SamplerOutput(outputs=[], sampled_token_ids=None),
            model_input.last_sampled_token_ids)

        # Release the ids of every earlier step; only the newest is kept.
        # TODO(will) we could reuse the sampled token ids tensor from
        # the previous step instead.
        history = model_input.cached_outputs
        for stale in history[:-1]:
            stale.sampled_token_ids = None
        assert history[-1].sampled_token_ids is not None
        return

    newest = model_input.cached_outputs[-1]
    assert newest.sampler_output.sampled_token_ids is None
    assert newest.sampled_token_ids is not None
    model_input.last_sampled_token_ids = newest.sampled_token_ids
    # The newest ids stay alive for GPU advance_step; earlier ones are
    # released only after they have been pythonized.
    for stale in model_input.cached_outputs[:-1]:
        if stale.pythonized:
            stale.sampled_token_ids = None

prepare_input

prepare_input(
    execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[
    Tuple[
        StatefulModelInput, WorkerInput, Dict[str, Tensor]
    ]
]

Depending on the current state of the request and multi step worker, this method may skip the normal prepare_worker_input and prepare_model_input calls and instead reuse cached values.

Source code in vllm/worker/multi_step_worker.py
def prepare_input(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str,
                                                          torch.Tensor]]]:
    """
    Prepare the inputs for one step of a multi-step run.

    Depending on the current state of the request and multi step worker,
    this method may skip the normal ``prepare_worker_input`` /
    ``prepare_model_input`` calls and instead reuse cached values.

    On the driver, a ``None`` request means there is no work. In that case an
    empty dict is broadcast (when metadata broadcast is enabled) to stop the
    other workers' loops, and ``None`` is returned. Non-driver workers receive
    their inputs via broadcast and return ``None`` when the driver sent that
    stop signal.

    Returns:
        ``(model_input, worker_input, kwargs)``, or ``None`` when there is
        nothing to execute.
    """
    if self.is_driver_worker:
        if execute_model_req is None:
            if self.do_metadata_broadcast:
                # This signals that there's no more requests to process for
                # now. All workers are running infinite loop with
                # broadcast_tensor_dict, and it stops the loop when the
                # driver broadcasts an empty input. Send an empty input to
                # notify all other workers to stop their execution loop.
                broadcast_tensor_dict({}, src=0)
            return None

        virtual_engine = execute_model_req.virtual_engine
        (model_input, worker_input,
         kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
        assert isinstance(model_input, StatefulModelInput)
        if execute_model_req.is_first_multi_step:
            # cache the worker input and model input for the next steps
            self.multi_step_states[virtual_engine] = MultiStepState(
                worker_input=worker_input, model_input=model_input)
    # non-driver (e.g. TP) workers
    else:
        broadcast_data = self._get_worker_input_from_broadcast()
        # if the driver has sent an empty input, we should stop the worker
        # loop
        if broadcast_data is None:
            return None
        model_input, worker_input, kwargs = broadcast_data
        assert isinstance(model_input, StatefulModelInput)
        # TODO(will) The worker could cache the worker input and model input
        # on the first step and reuse them afterwards. The idea is essentially
        # the delta optimization for model_inputs, where only the delta needed
        # for the next step (sampled_token_ids from the previous step) is
        # broadcast.
        if not model_input.is_first_multi_step:
            # we need to update the last sampled token ids in the model
            # input for the workers so that they can run inplace
            # advance_step
            model_input.add_sampler_output(
                SamplerOutput(outputs=[], sampled_token_ids=None),
                model_input.last_sampled_token_ids)

    assert model_input is not None
    assert worker_input is not None
    return model_input, worker_input, kwargs