vllm.model_executor.layers.typical_acceptance_sampler
TypicalAcceptanceSampler
Bases: SpecDecodeDeterministicBaseSampler
Apply typical acceptance sampling as described in Section 3.3.1 of "MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads" https://arxiv.org/pdf/2401.10774
Source code in vllm/model_executor/layers/typical_acceptance_sampler.py
__init__
Create a Typical Acceptance Sampler.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
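strict_mode | bool | Whether or not to perform shape/device/dtype checks | False |
posterior_threshold | | A threshold value that sets a lower bound on the posterior probability of a token in the target model for it to be accepted. | required |
posterior_alpha | | A scaling factor for the entropy-based threshold in typical acceptance sampling. | required |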
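As a rough illustration of how the two hyperparameters interact, the sketch below assumes the Medusa acceptance rule, in which the probability the target model assigns to a draft token must exceed min(posterior_threshold, posterior_alpha * exp(-H)), where H is the entropy of the target distribution at that position. The helper names entropy and dynamic_threshold are hypothetical, and the values 0.09 and 0.3 are illustrative only.

```python
import math


def entropy(probs):
    """Shannon entropy (natural log) of a discrete distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def dynamic_threshold(probs, posterior_threshold, posterior_alpha):
    """Required target probability: min(epsilon, delta * exp(-H(p)))."""
    return min(posterior_threshold, posterior_alpha * math.exp(-entropy(probs)))


# Zero entropy: exp(-H) = 1, so the threshold is min(epsilon, delta).
one_hot = [1.0, 0.0, 0.0, 0.0]
# Uniform over V = 4 tokens: H = ln(4), so delta * exp(-H) = delta / 4.
uniform = [0.25] * 4

print(dynamic_threshold(one_hot, 0.09, 0.3))  # 0.09
print(dynamic_threshold(uniform, 0.09, 0.3))  # approximately 0.075
```

Because a uniform distribution over V tokens has entropy ln V, the entropy term reduces to posterior_alpha / V: the required probability drops as the target model becomes less certain, while posterior_threshold caps it when the model is confident.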
_evaluate_accepted_tokens
Evaluates and returns a mask of accepted tokens based on the posterior probabilities.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
target_probs | Tensor | A tensor of shape (batch_size, k, vocab_size) representing the probabilities of each token in the vocabulary for each position in the proposed sequence. This is the distribution generated by the target model. | required |
draft_token_ids | Tensor | A tensor of shape (batch_size, k) representing the proposed token ids. | required |
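A draft token_id x_{n+k} is accepted if it satisfies the following condition:

$$
p_{\text{original}}(x_{n+k} \mid x_1, x_2, \dots, x_{n+k-1}) > \min\left\{\epsilon,\ \delta \exp\left(-H\left(p_{\text{original}}(\cdot \mid x_1, x_2, \dots, x_{n+k-1})\right)\right)\right\}
$$

where $p_{\text{original}}$ corresponds to target_probs, $H$ denotes entropy, and $\epsilon$ and $\delta$ correspond to the hyperparameters specified using self._posterior_threshold and self._posterior_alpha.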
This method computes the posterior probabilities for the given draft token ids based on the provided target probabilities. It calculates the entropy of the posterior distribution and determines a dynamic threshold for each token position using the provided posterior_threshold and posterior_alpha values. The method then returns a boolean mask indicating which tokens can be accepted.
Returns:
Type | Description |
---|---|
Tensor | A boolean tensor of shape (batch_size, k) where each element indicates whether the corresponding draft token has been accepted or rejected: True indicates acceptance and False indicates rejection. |
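A minimal PyTorch sketch of the acceptance test described above, assuming target_probs is already normalized. The function name evaluate_accepted_tokens and the small log_eps constant (added to avoid log(0)) are hypothetical, and vLLM's own method may differ in detail.

```python
import torch


def evaluate_accepted_tokens(
    target_probs: torch.Tensor,     # (batch_size, k, vocab_size)
    draft_token_ids: torch.Tensor,  # (batch_size, k), int64
    posterior_threshold: float,     # epsilon in the acceptance rule
    posterior_alpha: float,         # delta in the acceptance rule
    log_eps: float = 1e-5,          # hypothetical guard against log(0)
) -> torch.Tensor:
    """Return a (batch_size, k) bool mask; True marks an accepted draft token."""
    # Probability the target model assigns to each proposed token.
    candidates_prob = torch.gather(
        target_probs, dim=-1, index=draft_token_ids.unsqueeze(-1)
    ).squeeze(-1)
    # Entropy of the target distribution at every position.
    posterior_entropy = -torch.sum(
        target_probs * torch.log(target_probs + log_eps), dim=-1
    )
    # Dynamic threshold: min(epsilon, delta * exp(-H)), element-wise.
    threshold = torch.minimum(
        torch.full_like(posterior_entropy, posterior_threshold),
        posterior_alpha * torch.exp(-posterior_entropy),
    )
    return candidates_prob > threshold


# Toy batch: one sequence, k = 2 proposed tokens, vocabulary of 4.
target_probs = torch.tensor([[[0.70, 0.10, 0.10, 0.10],
                              [0.90, 0.04, 0.03, 0.03]]])
draft_token_ids = torch.tensor([[0, 3]])
mask = evaluate_accepted_tokens(target_probs, draft_token_ids, 0.09, 0.3)
print(mask.tolist())  # [[True, False]]
```

The second draft token is rejected because its target probability (0.03) falls below the capped threshold of 0.09, while the first token's 0.70 clears it easily.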
_get_recovered_token_ids
Returns the recovered token ids, which fill the first unmatched (rejected) position with a token chosen by the target model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
target_probs | Tensor | A tensor of shape (batch_size, k, vocab_size) containing the target probability distribution. | required |
Returns:
Type | Description |
---|---|
Tensor | A tensor of shape (batch_size, k) with the recovered token ids, which are selected from the target probs. |
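A short sketch, assuming the recovered token is the target model's most likely token at each position (a greedy argmax); the description above only says the ids are selected from the target probs, and the function name get_recovered_token_ids is hypothetical.

```python
import torch


def get_recovered_token_ids(target_probs: torch.Tensor) -> torch.Tensor:
    """Map (batch_size, k, vocab_size) probabilities to (batch_size, k) token ids."""
    # Pick the most likely token under the target model at every position.
    return torch.argmax(target_probs, dim=-1)


target_probs = torch.tensor([[[0.70, 0.10, 0.10, 0.10],
                              [0.05, 0.15, 0.60, 0.20]]])
print(get_recovered_token_ids(target_probs).tolist())  # [[0, 2]]
```

A greedy choice like this needs no random sampling, which would be consistent with the class deriving from SpecDecodeDeterministicBaseSampler.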
forward
forward(
target_with_bonus_probs: Tensor,
bonus_token_ids: Tensor,
draft_probs: Tensor,
draft_token_ids: Tensor,
) -> Tensor
Sample token ids using typical acceptance sampling. This accepts or rejects tokens proposed by the draft model using the probability that the target model assigns to each token.
In the worst case, where all draft tokens are rejected, one token is still guaranteed to be emitted.
In the case where all draft tokens are accepted, the bonus token will be accepted.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
target_with_bonus_probs | Tensor | The probability distribution over token ids given context according to the target model. | required |
bonus_token_ids | Tensor | The "bonus" token ids that are accepted if and only if all speculative tokens in a sequence are accepted. | required |
draft_probs | Tensor | This parameter is unused by the acceptance sampler. | required |
draft_token_ids | Tensor | The token ids that were sampled from the draft probabilities. | required |
Returns:
Name | Type | Description |
---|---|---|
output_token_ids | Tensor | The token ids sampled via typical acceptance sampling, or -1 if unable to sample a token because the previous token was rejected. shape = [batch_size, num_speculative_tokens + num_bonus_tokens] |
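To make the output contract concrete (draft tokens up to the first rejection, then one recovered token, a bonus token only when every draft token is accepted, and -1 everywhere else), here is a hypothetical assembly step. It assumes an acceptance mask and recovered ids like those returned by the two helper methods, plus bonus_token_ids of shape (batch_size, 1); the function name assemble_output is illustrative, and vLLM's actual implementation may differ.

```python
import torch


def assemble_output(
    accepted: torch.Tensor,             # (B, k) bool acceptance mask
    draft_token_ids: torch.Tensor,      # (B, k) int64
    recovered_token_ids: torch.Tensor,  # (B, k) int64
    bonus_token_ids: torch.Tensor,      # (B, 1) int64
) -> torch.Tensor:
    batch_size, k = accepted.shape
    # Length of the accepted prefix: cumprod stays 1 only until the first False.
    num_accepted = torch.cumprod(accepted.long(), dim=1).sum(dim=1)  # (B,)
    positions = torch.arange(k, device=accepted.device).unsqueeze(0)  # (1, k)
    output = torch.full(
        (batch_size, k + 1), -1, dtype=torch.long, device=accepted.device
    )
    # Keep draft tokens in the accepted prefix.
    keep = positions < num_accepted.unsqueeze(1)
    output[:, :k] = torch.where(keep, draft_token_ids, output[:, :k])
    # At the first rejected position, emit the recovered token instead.
    first_reject = positions == num_accepted.unsqueeze(1)
    output[:, :k] = torch.where(first_reject, recovered_token_ids, output[:, :k])
    # Append the bonus token only when all k draft tokens were accepted.
    all_accepted = num_accepted == k
    output[:, k] = torch.where(all_accepted, bonus_token_ids.squeeze(1), output[:, k])
    return output


accepted = torch.tensor([[True, False, True], [True, True, True]])
draft = torch.tensor([[11, 12, 13], [21, 22, 23]])
recovered = torch.tensor([[31, 32, 33], [41, 42, 43]])
bonus = torch.tensor([[99], [98]])
out = assemble_output(accepted, draft, recovered, bonus)
print(out.tolist())  # [[11, 32, -1, -1], [21, 22, 23, 98]]
```

The first sequence shows that a draft token accepted after an earlier rejection (13) is still discarded; if the very first draft token were rejected, position 0 would hold the recovered token, so at least one token is always emitted.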