
vllm.v1.sample.tpu.metadata

DEFAULT_SAMPLING_PARAMS module-attribute

DEFAULT_SAMPLING_PARAMS = dict(
    temperature=-1.0, min_p=0.0, top_k=0, top_p=1.0
)
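
These are the values used to pad the unused tail of the per-request sampling tensors (see fill_slice in from_input_batch below). A minimal sketch, assuming standard top-p/top-k/min-p semantics, of why they are safe as padding: they leave the distribution unfiltered, and temperature=-1.0 matches the sentinel value vLLM stores for greedy requests.

import torch

from vllm.v1.sample.tpu.metadata import DEFAULT_SAMPLING_PARAMS

# Sketch (illustrative values only): the defaults are "neutral" parameters,
# so slots padded with them never filter the distribution.
#   top_k=0          -> top-k filtering disabled
#   top_p=1.0        -> nucleus filtering keeps the full distribution
#   min_p=0.0        -> a threshold of 0 keeps every token
#   temperature=-1.0 -> sentinel value for greedy requests
probs = torch.softmax(torch.tensor([2.0, 1.0, 0.5, -1.0]), dim=-1)

min_p = DEFAULT_SAMPLING_PARAMS["min_p"]
mask = probs < min_p * probs.max()  # min-p mask with the default threshold
assert not mask.any()               # nothing is filtered out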

TPUSupportedSamplingMetadata dataclass

Source code in vllm/v1/sample/tpu/metadata.py
@dataclass
class TPUSupportedSamplingMetadata:
    # This class exposes a more xla-friendly interface than SamplingMetadata
    # on TPU, in particular all arguments should be traceable and no optionals
    # are allowed, to avoid graph recompilation on Nones.
    temperature: torch.Tensor = None

    min_p: torch.Tensor = None
    top_k: torch.Tensor = None
    top_p: torch.Tensor = None

    all_greedy: bool = True

    # Whether logprobs are to be gathered in this batch of request. To balance
    # out compile time and runtime, a fixed `max_number_logprobs` value is used
    # when gathering logprobs, regardless of the values specified in the batch.
    logprobs: bool = False

    # TODO No penalties for now
    no_penalties: bool = True
    prompt_token_ids = None
    frequency_penalties = None
    presence_penalties = None
    repetition_penalties = None
    # should use tensor
    output_token_ids: list[list[int]] = field(default_factory=lambda: list())

    min_tokens = None  # impl is not vectorized

    logit_bias: list[Optional[dict[int, float]]] = field(
        default_factory=lambda: list())

    allowed_token_ids_mask = None
    bad_words_token_ids = None

    # Generator not supported by xla
    _generators: dict[int,
                      torch.Generator] = field(default_factory=lambda: dict())

    @property
    def generators(self) -> dict[int, torch.Generator]:
        # Generator not supported by torch/xla. This field must be immutable.
        return self._generators

    @classmethod
    def from_input_batch(
        cls,
        input_batch: InputBatch,
        padded_num_reqs: int,
        xla_device: torch.device,
        generate_params_if_all_greedy: bool = False
    ) -> "TPUSupportedSamplingMetadata":
        """
        Copy sampling tensors slices from `input_batch` to on device tensors.

        `InputBatch._make_sampling_metadata` causes recompilation on XLA as it 
        slices dynamic shapes on device tensors. This impl moves the dynamic 
        ops to CPU and produces tensors of fixed `padded_num_reqs` size.

        Args:
            input_batch: The input batch containing sampling parameters.
            padded_num_reqs: The padded number of requests.
            xla_device: The XLA device.
            generate_params_if_all_greedy: If True, generate sampling parameters
                even if all requests are greedy. this is useful for cases where
                we want to pre-compile a graph with sampling parameters, even if
                they are not strictly needed for greedy decoding.
        """
        needs_logprobs = input_batch.max_num_logprobs>0 if \
            input_batch.max_num_logprobs else False
        # Early return to avoid unnecessary cpu to tpu copy
        if (input_batch.all_greedy is True
                and generate_params_if_all_greedy is False):
            return cls(all_greedy=True, logprobs=needs_logprobs)

        num_reqs = input_batch.num_reqs

        def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
            # Pad value is the default one.
            cpu_tensor[num_reqs:padded_num_reqs] = fill_val

        fill_slice(input_batch.temperature_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["temperature"])
        fill_slice(input_batch.min_p_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["min_p"])
        fill_slice(input_batch.top_k_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["top_k"])
        fill_slice(input_batch.top_p_cpu_tensor,
                   DEFAULT_SAMPLING_PARAMS["top_p"])

        # Slice persistent device tensors to a fixed pre-compiled padded shape.
        return cls(
            temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
            to(xla_device),
            all_greedy=input_batch.all_greedy,
            # TODO enable more and avoid returning None values
            top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(
                xla_device),
            top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
                xla_device),
            min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
                xla_device),
            logprobs=needs_logprobs)
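
A hedged usage sketch of how a TPU model runner might build this metadata each step. Only TPUSupportedSamplingMetadata.from_input_batch and its fields come from the source above; the helper name, the untyped input_batch argument, and the torch_xla device lookup are illustrative assumptions.

import torch_xla.core.xla_model as xm  # requires torch_xla

from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata


def build_tpu_sampling_metadata(input_batch, padded_num_reqs: int):
    # Hypothetical helper around the classmethod shown above.
    metadata = TPUSupportedSamplingMetadata.from_input_batch(
        input_batch, padded_num_reqs, xm.xla_device())

    if metadata.all_greedy:
        # Early-return path: no CPU->TPU copy happened and the parameter
        # tensors stay None; the caller should argmax the logits directly.
        assert metadata.temperature is None
    else:
        # Fixed-shape device tensors, one entry per (padded) request slot,
        # so the sampling graph compiles once per padded batch size.
        assert metadata.temperature.shape == (padded_num_reqs,)
    return metadata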

_generators class-attribute instance-attribute

_generators: dict[int, Generator] = field(
    default_factory=lambda: dict()
)

all_greedy class-attribute instance-attribute

all_greedy: bool = True

allowed_token_ids_mask class-attribute instance-attribute

allowed_token_ids_mask = None

bad_words_token_ids class-attribute instance-attribute

bad_words_token_ids = None

frequency_penalties class-attribute instance-attribute

frequency_penalties = None

generators property

generators: dict[int, Generator]

logit_bias class-attribute instance-attribute

logit_bias: list[Optional[dict[int, float]]] = field(
    default_factory=lambda: list()
)

logprobs class-attribute instance-attribute

logprobs: bool = False

min_p class-attribute instance-attribute

min_p: Tensor = None

min_tokens class-attribute instance-attribute

min_tokens = None

no_penalties class-attribute instance-attribute

no_penalties: bool = True

output_token_ids class-attribute instance-attribute

output_token_ids: list[list[int]] = field(
    default_factory=lambda: list()
)

presence_penalties class-attribute instance-attribute

presence_penalties = None

prompt_token_ids class-attribute instance-attribute

prompt_token_ids = None

repetition_penalties class-attribute instance-attribute

repetition_penalties = None

temperature class-attribute instance-attribute

temperature: Tensor = None

top_k class-attribute instance-attribute

top_k: Tensor = None

top_p class-attribute instance-attribute

top_p: Tensor = None

__init__

__init__(
    temperature: Tensor = None,
    min_p: Tensor = None,
    top_k: Tensor = None,
    top_p: Tensor = None,
    all_greedy: bool = True,
    logprobs: bool = False,
    no_penalties: bool = True,
    output_token_ids: list[list[int]] = list(),
    logit_bias: list[Optional[dict[int, float]]] = list(),
    _generators: dict[int, Generator] = dict(),
) -> None

from_input_batch classmethod

from_input_batch(
    input_batch: InputBatch,
    padded_num_reqs: int,
    xla_device: device,
    generate_params_if_all_greedy: bool = False,
) -> TPUSupportedSamplingMetadata

Copy slices of the sampling tensors from input_batch to on-device tensors.

InputBatch._make_sampling_metadata causes recompilation on XLA because it slices dynamic shapes on device tensors. This implementation moves the dynamic ops to CPU and produces tensors of a fixed padded_num_reqs size.

Parameters:

    input_batch (InputBatch): The input batch containing sampling parameters.
        Required.
    padded_num_reqs (int): The padded number of requests. Required.
    xla_device (device): The XLA device. Required.
    generate_params_if_all_greedy (bool): If True, generate sampling parameters
        even if all requests are greedy. This is useful for cases where we want
        to pre-compile a graph with sampling parameters, even if they are not
        strictly needed for greedy decoding. Default: False.
Source code in vllm/v1/sample/tpu/metadata.py
@classmethod
def from_input_batch(
    cls,
    input_batch: InputBatch,
    padded_num_reqs: int,
    xla_device: torch.device,
    generate_params_if_all_greedy: bool = False
) -> "TPUSupportedSamplingMetadata":
    """
    Copy sampling tensors slices from `input_batch` to on device tensors.

    `InputBatch._make_sampling_metadata` causes recompilation on XLA as it 
    slices dynamic shapes on device tensors. This impl moves the dynamic 
    ops to CPU and produces tensors of fixed `padded_num_reqs` size.

    Args:
        input_batch: The input batch containing sampling parameters.
        padded_num_reqs: The padded number of requests.
        xla_device: The XLA device.
        generate_params_if_all_greedy: If True, generate sampling parameters
            even if all requests are greedy. this is useful for cases where
            we want to pre-compile a graph with sampling parameters, even if
            they are not strictly needed for greedy decoding.
    """
    needs_logprobs = input_batch.max_num_logprobs>0 if \
        input_batch.max_num_logprobs else False
    # Early return to avoid unnecessary cpu to tpu copy
    if (input_batch.all_greedy is True
            and generate_params_if_all_greedy is False):
        return cls(all_greedy=True, logprobs=needs_logprobs)

    num_reqs = input_batch.num_reqs

    def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
        # Pad value is the default one.
        cpu_tensor[num_reqs:padded_num_reqs] = fill_val

    fill_slice(input_batch.temperature_cpu_tensor,
               DEFAULT_SAMPLING_PARAMS["temperature"])
    fill_slice(input_batch.min_p_cpu_tensor,
               DEFAULT_SAMPLING_PARAMS["min_p"])
    fill_slice(input_batch.top_k_cpu_tensor,
               DEFAULT_SAMPLING_PARAMS["top_k"])
    fill_slice(input_batch.top_p_cpu_tensor,
               DEFAULT_SAMPLING_PARAMS["top_p"])

    # Slice persistent device tensors to a fixed pre-compiled padded shape.
    return cls(
        temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
        to(xla_device),
        all_greedy=input_batch.all_greedy,
        # TODO enable more and avoid returning None values
        top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(
            xla_device),
        top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
            xla_device),
        min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
            xla_device),
        logprobs=needs_logprobs)
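
A small CPU-only sketch of what fill_slice does to the persistent tensors: the slots between num_reqs and padded_num_reqs are overwritten with the neutral defaults, so padding slots cannot affect sampling, while the first num_reqs entries keep the real per-request values. The tensor below is illustrative and not the actual InputBatch layout.

import torch

from vllm.v1.sample.tpu.metadata import DEFAULT_SAMPLING_PARAMS

# Illustrative numbers only: 3 real requests padded to a bucket of 8 slots.
num_reqs, padded_num_reqs = 3, 8
top_p_cpu = torch.ones(16)  # stand-in for a persistent max_num_reqs CPU tensor
top_p_cpu[:num_reqs] = torch.tensor([0.9, 0.5, 0.8])  # real request values

# Same idea as fill_slice: pad the tail of the padded region with the default.
top_p_cpu[num_reqs:padded_num_reqs] = DEFAULT_SAMPLING_PARAMS["top_p"]

# Only the first padded_num_reqs slots are sliced and copied to the TPU, so the
# compiled graph always sees a (padded_num_reqs,) tensor regardless of num_reqs.
print(top_p_cpu[:padded_num_reqs])  # ~ [0.9, 0.5, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0]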