Skip to content

vllm.v1.engine.parallel_sampling

ParentRequest

Info, state, and processing for a parallel sampling request.

Stores the parent request ID and sampling params, and facilitates generating the sampling params for each child request.

Source code in vllm/v1/engine/parallel_sampling.py
class ParentRequest:
    """Info, state & processing for a parallel sampling request.

    Store the parent request ID and sampling params, hand out per-child
    request IDs and sampling params, and merge the children's outputs back
    into results for the parent request.
    """

    request_id: str
    sampling_params: SamplingParams

    # IDs of child requests that have not finished yet. A child is added by
    # get_child_info() and removed when its finished output is seen.
    child_requests: set[str]

    # To aggregate child completions when not streaming (FINAL_ONLY): one
    # slot per child, indexed by CompletionOutput.index. A slot stays None
    # until that child's output arrives. Empty when streaming.
    output_aggregator: list[Optional[CompletionOutput]]

    # To find the max number of generated tokens across all children
    max_num_generation_tokens: int

    # Child sampling params shared by all children; only populated when the
    # parent has no seed (seeded children each need their own copy).
    cached_child_sampling_params: Optional[SamplingParams]

    def __init__(self, request_id: str,
                 sampling_params: SamplingParams) -> None:
        self.request_id = request_id
        self.sampling_params = sampling_params

        self.child_requests = set()
        self.output_aggregator = [None] * sampling_params.n if (
            sampling_params.output_kind
            == RequestOutputKind.FINAL_ONLY) else []
        self.max_num_generation_tokens = 0
        self.cached_child_sampling_params = None

    def _get_child_sampling_params(
        self,
        index: int,
    ) -> SamplingParams:
        """Efficiently obtain child `sampling_params`.

        Child params are a shallow copy of the parent's with `n = 1`.
        If `sampling_params.seed` is `None`, one such copy is built on first
        use, cached, and shared by every child. Otherwise each child gets
        its own copy whose seed is `seed + index`, so no two children share
        a seed.

        Args:
          index: index within `n` child requests

        Returns:
          Child `sampling_params` instance.
        """
        # Explicit None check: the cache must be honoured even if the params
        # object happens to evaluate as falsy.
        if self.cached_child_sampling_params is not None:
            # Reuse child sampling_params data structure
            return self.cached_child_sampling_params
        seed = self.sampling_params.seed
        # Build child sampling_params
        child_sampling_params = copy(self.sampling_params)
        child_sampling_params.n = 1
        if seed is None:
            # Cache child sampling_params for later reuse
            self.cached_child_sampling_params = child_sampling_params
        else:
            # Each child gets a clone with a unique seed
            child_sampling_params.seed = seed + index
        return child_sampling_params

    def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
        """Get child request ID and sampling params.

        The child is also registered in `child_requests` so its completion
        can be tracked.

        Args:
          index: index within `n` child requests.

        Returns:
          (request ID, sampling_params) tuple; the ID is
          `f"{index}_{request_id}"`.
        """
        child_req_id = f"{index}_{self.request_id}"
        self.child_requests.add(child_req_id)
        return child_req_id, self._get_child_sampling_params(index)

    @property
    def n(self) -> int:
        """Number of child requests (the parent's `sampling_params.n`)."""
        return self.sampling_params.n

    def get_outputs(
        self,
        child_request_id: str,
        completion_output: CompletionOutput,
    ) -> tuple[str, list[CompletionOutput], bool]:
        """Turn one child's output into outputs for the parent request.

        Args:
          child_request_id: ID of the child that produced
            `completion_output` (as returned by `get_child_info`).
          completion_output: the child's latest completion output.

        Returns:
          `(parent request ID, outputs, finished)`. When streaming,
          `outputs` is `[completion_output]`. With `FINAL_ONLY`, it is empty
          until every child has finished, and then it holds the aggregated
          outputs ordered by `completion_output.index`. A finished output for
          a child that already finished in an earlier step yields no
          outputs, so nothing is reported twice. `finished` is True once no
          child requests remain.
        """
        already_finished = False
        if completion_output.finished():
            if child_request_id in self.child_requests:
                self.child_requests.remove(child_request_id)
            else:
                # The child is no longer tracked: it finished (and was
                # handled) in an earlier step. Removing it again would
                # raise KeyError.
                already_finished = True

        if already_finished:
            # Don't re-emit a finished child or the final aggregate.
            outputs = []
        elif self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
            # If streaming, just return the current output.
            outputs = [completion_output]
        else:
            # If not streaming, aggregate the n final outputs.
            self.output_aggregator[completion_output.index] = completion_output
            outputs = [] if self.child_requests else self.output_aggregator

        finished = not self.child_requests
        return self.request_id, outputs, finished

    def observe_num_generation_tokens(self, num_generation_tokens: int) -> int:
        """Update and return the max generated-token count across children."""
        self.max_num_generation_tokens = max(num_generation_tokens,
                                             self.max_num_generation_tokens)
        return self.max_num_generation_tokens

    @staticmethod
    def observe_finished_request(parent_req: Optional['ParentRequest'],
                                 iteration_stats: IterationStats,
                                 num_generation_tokens: int):
        """Record a finished request's token count in `iteration_stats`.

        A standalone request (`parent_req is None`) is recorded immediately
        with `n = 1`. A child request only updates its parent's running
        maximum. The parent's maximum and `n` are appended once the parent
        has no child requests left.
        """

        n_param = parent_req.n if parent_req is not None else 1

        if parent_req is not None:
            num_generation_tokens = parent_req.observe_num_generation_tokens(
                num_generation_tokens)

        # Child requests finished, we can now record to iteration stats
        if parent_req is None or not parent_req.child_requests:
            iteration_stats.max_num_generation_tokens_iter.append(
                num_generation_tokens)
            iteration_stats.n_params_iter.append(n_param)

cached_child_sampling_params instance-attribute

cached_child_sampling_params: Optional[SamplingParams] = (
    None
)

child_requests instance-attribute

child_requests: set[str] = set()

max_num_generation_tokens instance-attribute

max_num_generation_tokens: int = 0

n property

n: int

output_aggregator instance-attribute

output_aggregator: list[CompletionOutput] = (
    [None] * n if output_kind == FINAL_ONLY else []
)

request_id instance-attribute

request_id: str = request_id

sampling_params instance-attribute

sampling_params: SamplingParams = sampling_params

__init__

__init__(
    request_id: str, sampling_params: SamplingParams
) -> None
Source code in vllm/v1/engine/parallel_sampling.py
def __init__(self, request_id: str,
             sampling_params: SamplingParams) -> None:
    """Set up the bookkeeping for a parallel sampling parent request.

    Args:
      request_id: ID of the parent request; child IDs are built from it.
      sampling_params: the parent's sampling params, including `n`.
    """
    self.request_id = request_id
    self.sampling_params = sampling_params

    self.child_requests = set()
    final_only = (
        sampling_params.output_kind == RequestOutputKind.FINAL_ONLY)
    if final_only:
        # Non-streaming: reserve one slot per child for aggregation.
        self.output_aggregator = [None] * sampling_params.n
    else:
        self.output_aggregator = []
    self.max_num_generation_tokens = 0
    self.cached_child_sampling_params = None

_get_child_sampling_params

_get_child_sampling_params(index: int) -> SamplingParams

Efficiently obtain child sampling_params

If `sampling_params.seed` is not `None`, each child request requires a unique clone of the parent `sampling_params` with a unique seed; otherwise a single clone is cached and shared by all children.

Parameters:

Name Type Description Default
index int

index within n child requests

required

Returns:

Type Description
SamplingParams

Child sampling_params instance.

Source code in vllm/v1/engine/parallel_sampling.py
def _get_child_sampling_params(
    self,
    index: int,
) -> SamplingParams:
    """Efficiently obtain child `sampling_params`

    If `sampling_params.seed` is not `None` then 
    each child request requires a unique clone of
    parent `sampling_params` with a unique seed.

    Args:
      index: index within `n` child requests

    Returns:
      Child `sampling_params` instance.
    """
    seed = self.sampling_params.seed
    if self.cached_child_sampling_params:
        # Reuse child sampling_params data structure
        return self.cached_child_sampling_params
    # Build child sampling_params
    child_sampling_params = copy(self.sampling_params)
    child_sampling_params.n = 1
    if seed is None:
        # Cache child sampling_params for later reuse
        self.cached_child_sampling_params = child_sampling_params
    else:
        # Each child gets a clone with a unique seed
        child_sampling_params.seed = seed + index
    return child_sampling_params

get_child_info

get_child_info(index: int) -> tuple[str, SamplingParams]

Get child request ID and sampling params.

Parameters:

Name Type Description Default
index int

index within n child requests.

required

Returns:

Type Description
tuple[str, SamplingParams]

(request ID, sampling_params) tuple

Source code in vllm/v1/engine/parallel_sampling.py
def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
    """Register one child request and return what is needed to submit it.

    Args:
      index: position of this child among the parent's `n` children.

    Returns:
      A `(child_request_id, child_sampling_params)` pair, where the ID is
      the index and the parent request ID joined by an underscore.
    """
    child_request_id = f"{index}_{self.request_id}"
    # Track the child so completion of all children can be detected.
    self.child_requests.add(child_request_id)
    child_params = self._get_child_sampling_params(index)
    return child_request_id, child_params

get_outputs

get_outputs(
    child_request_id: str,
    completion_output: CompletionOutput,
) -> tuple[str, list[CompletionOutput], bool]
Source code in vllm/v1/engine/parallel_sampling.py
def get_outputs(
    self,
    child_request_id: str,
    completion_output: CompletionOutput,
) -> tuple[str, list[CompletionOutput], bool]:
    """Turn one child's output into outputs for the parent request.

    Args:
      child_request_id: ID of the child that produced `completion_output`
        (as returned by `get_child_info`).
      completion_output: the child's latest completion output.

    Returns:
      `(parent request ID, outputs, finished)`. When streaming, `outputs`
      is `[completion_output]`. With `FINAL_ONLY`, it is empty until every
      child has finished, and then it holds the aggregated outputs ordered
      by `completion_output.index`. A finished output for a child that
      already finished in an earlier step yields no outputs, so nothing is
      reported twice. `finished` is True once no child requests remain.
    """
    already_finished = False
    if completion_output.finished():
        if child_request_id in self.child_requests:
            self.child_requests.remove(child_request_id)
        else:
            # The child is no longer tracked: it finished (and was handled)
            # in an earlier step. Removing it again would raise KeyError.
            already_finished = True

    if already_finished:
        # Don't re-emit a finished child or the final aggregate.
        outputs = []
    elif self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
        # If streaming, just return the current output.
        outputs = [completion_output]
    else:
        # If not streaming, aggregate the n final outputs.
        self.output_aggregator[completion_output.index] = completion_output
        outputs = [] if self.child_requests else self.output_aggregator

    finished = not self.child_requests
    return self.request_id, outputs, finished

observe_finished_request staticmethod

observe_finished_request(
    parent_req: Optional[ParentRequest],
    iteration_stats: IterationStats,
    num_generation_tokens: int,
)
Source code in vllm/v1/engine/parallel_sampling.py
@staticmethod
def observe_finished_request(parent_req: Optional['ParentRequest'],
                             iteration_stats: IterationStats,
                             num_generation_tokens: int):

    n_param = parent_req.n if parent_req is not None else 1

    if parent_req is not None:
        num_generation_tokens = parent_req.observe_num_generation_tokens(
            num_generation_tokens)

    # Child requests finished, we can now record to iteration stats
    if parent_req is None or not parent_req.child_requests:
        iteration_stats.max_num_generation_tokens_iter.append(
            num_generation_tokens)
        iteration_stats.n_params_iter.append(n_param)

observe_num_generation_tokens

observe_num_generation_tokens(num_generation_tokens: int)
Source code in vllm/v1/engine/parallel_sampling.py
def observe_num_generation_tokens(self, num_generation_tokens: int):
    """Fold one child's generated-token count into the running maximum.

    Returns:
      The largest token count observed so far across children.
    """
    previous_max = self.max_num_generation_tokens
    updated_max = max(num_generation_tokens, previous_max)
    self.max_num_generation_tokens = updated_max
    return updated_max