vllm.v1.sample.ops.topk_topp_sampler
TopKTopPSampler
Bases: Module
Module that performs optional top-k and top-p filtering on the logits, followed by weighted random sampling.
Implementations may update the logits tensor in-place.
Source code in vllm/v1/sample/ops/topk_topp_sampler.py
__init__
forward_cuda
forward_cuda(
    logits: Tensor,
    generators: dict[int, Generator],
    k: Optional[Tensor],
    p: Optional[Tensor],
) -> Tensor
CUDA-optimized implementation of top-k and top-p sampling.
forward_native
forward_native(
    logits: Tensor,
    generators: dict[int, Generator],
    k: Optional[Tensor],
    p: Optional[Tensor],
) -> Tensor
PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place.
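As an illustration of what "filtering followed by weighted random sampling" means here (a minimal sketch, not vLLM's actual code): filtered-out logits are set to -inf, softmax renormalizes the probability mass over the surviving tokens, and a seeded per-request generator makes the draw reproducible. torch.multinomial is used below purely for clarity; the real implementation avoids it (see random_sample below).

```python
import torch

# Two tokens survived filtering; the rest are masked to -inf.
logits = torch.tensor([[2.0, 1.0, float("-inf"), float("-inf")]])
# Softmax assigns zero probability to the masked entries and
# renormalizes the mass over the kept tokens.
probs = logits.softmax(dim=-1)
# A per-request generator makes the draw reproducible for seeded requests.
gen = torch.Generator().manual_seed(0)
token = torch.multinomial(probs, num_samples=1, generator=gen)
```
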
apply_top_k_only
Apply top-k mask to the logits.
This implementation doesn't involve sorting the entire vocab.
The logits tensor may be updated in-place.
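A sketch of the sort-free idea, under assumed names (the helper below is illustrative, not vLLM's exact code): torch.topk retrieves only the largest max(k) entries per row instead of sorting the whole vocabulary, and each row is then thresholded at its own k-th largest value.

```python
import torch


def top_k_only(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Illustrative top-k mask without a full-vocab sort.

    `k` holds one top-k value per row of `logits`. Entries below each
    row's k-th largest logit are masked to -inf.
    """
    max_k = int(k.max())
    # Retrieve only the max(k) largest values per row, not a full sort.
    topk_vals = logits.topk(max_k, dim=-1).values        # (batch, max_k)
    # Per-row threshold: the k-th largest value of that row.
    kth = topk_vals.gather(-1, (k - 1).unsqueeze(-1))    # (batch, 1)
    return logits.masked_fill(logits < kth, float("-inf"))
```
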
apply_top_k_top_p
Apply top-k and top-p masks to the logits.
If a top-p is used, this function will sort the logits tensor, which can be slow for large batches.
The logits tensor may be updated in-place.
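The sort-based masking described above can be sketched as follows (an illustrative re-implementation, not vLLM's exact code; the helper name and the exclusive-cumsum formulation of top-p are assumptions). The top-p step is where the full-vocab sort happens, which is why it can be slow for large batches:

```python
import torch
from typing import Optional


def topk_topp_mask(logits: torch.Tensor,
                   k: Optional[torch.Tensor],
                   p: Optional[torch.Tensor]) -> torch.Tensor:
    """Illustrative per-row top-k then top-p masking via sorting."""
    logits = logits.clone()  # the real op may instead work in-place
    if k is not None:
        # Find each row's k-th largest value and mask everything below it.
        sorted_vals = logits.sort(dim=-1, descending=True).values
        kth = sorted_vals.gather(-1, (k - 1).unsqueeze(-1))
        logits.masked_fill_(logits < kth, float("-inf"))
    if p is not None:
        sorted_vals, sorted_idx = logits.sort(dim=-1, descending=True)
        probs = sorted_vals.softmax(dim=-1)
        # Exclusive cumulative sum: drop tokens whose *preceding* mass
        # already exceeds p; this keeps the token that crosses p.
        excl_cum = probs.cumsum(dim=-1) - probs
        sorted_vals.masked_fill_(excl_cum > p.unsqueeze(-1), float("-inf"))
        # Scatter the masked values back to their original positions.
        logits = torch.full_like(logits, float("-inf")).scatter(
            -1, sorted_idx, sorted_vals)
    return logits
```
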
apply_top_k_top_p_tpu
Apply top-k and top-p optimized for TPU.
This algorithm avoids torch.scatter, which is extremely slow on TPU. Instead, it finds a "cut-off" element in the original logits; after thresholding the logits against this cut-off, the remaining elements constitute the top-p set.
Note: in the case of a tie (i.e., multiple cut-off elements present in the logits), all tied elements are included in the top-p set. In other words, this function does not break ties; tied tokens have an equal chance of being chosen during final sampling, so the tie can be considered broken then.
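The cut-off idea can be sketched as follows (illustrative only; the helper name and the prefix-count formulation are assumptions, and only the top-p half is shown). Because the original tensor is thresholded directly against the cut-off logit, no scatter back to original positions is needed, and a `>=`-style comparison keeps all ties at the cut-off:

```python
import torch


def top_p_no_scatter(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Illustrative scatter-free top-p: threshold at a cut-off logit."""
    sorted_vals = logits.sort(dim=-1, descending=True).values
    cum = sorted_vals.softmax(dim=-1).cumsum(dim=-1)
    # Count sorted positions whose cumulative mass is still below p;
    # that count is the index of the cut-off element.
    cutoff_idx = (cum < p.unsqueeze(-1)).sum(dim=-1, keepdim=True)
    cutoff_idx = cutoff_idx.clamp(max=logits.size(-1) - 1)
    cutoff_logit = sorted_vals.gather(-1, cutoff_idx)
    # Threshold in the original layout: no scatter needed, and all
    # elements tied with the cut-off survive.
    return logits.masked_fill(logits < cutoff_logit, float("-inf"))
```
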
flashinfer_sample
flashinfer_sample(
    logits: Tensor,
    k: Optional[Tensor],
    p: Optional[Tensor],
    generators: dict[int, Generator],
) -> Tensor
Sample from the logits using FlashInfer.
Statistically, this function is equivalent to the random_sample function. However, it is faster because it uses rejection sampling, which avoids sorting the logits tensor.
NOTE: The outputs of this function do not necessarily match the outputs of random_sample; only statistical equivalence is guaranteed.
NOTE: This function incurs CPU-GPU synchronization, while random_sample does not. Call this function at the end of the forward pass to minimize the synchronization overhead.
random_sample
Randomly sample from the probabilities.
We use this function instead of torch.multinomial because torch.multinomial causes CPU-GPU synchronization.
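One multinomial-free approach (a sketch under assumed names, not necessarily vLLM's exact code) is the exponential-noise form of the Gumbel-max trick: dividing the probabilities by i.i.d. Exponential(1) noise and taking the per-row argmax yields a draw from the categorical distribution, with no CPU-GPU synchronization and with per-request seeded generators honored:

```python
import torch


def random_sample_sketch(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Illustrative categorical sampling without torch.multinomial.

    argmax(probs / q) with q ~ Exponential(1) is distributed according
    to `probs` (the exponential form of the Gumbel-max trick).
    """
    q = torch.empty_like(probs)
    q.exponential_()                        # noise for un-seeded requests
    for i, gen in generators.items():
        q[i].exponential_(generator=gen)    # row i uses its own seed
    return probs.div(q).argmax(dim=-1)
```
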