
vllm.spec_decode.batch_expansion

DEFAULT_SIMPLE_SAMPLING_PARAMS module-attribute

DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()

SeqId module-attribute

SeqId = int

TargetSeqId module-attribute

TargetSeqId = int

TokenId module-attribute

TokenId = int

BatchExpansionTop1Scorer

Bases: SpeculativeScorer

Implements a speculative scorer that uses batch expansion to get probabilities of speculative tokens according to the scoring model.

Batch expansion converts a list of sequences and multiple query positions to a new batch of sequences, each with a single query position. This allows for MQA-like scoring in speculative decoding without requiring an MQA kernel.

It is strictly less efficient than MQA scoring.

It only supports scoring the top1 proposal tokens of the proposer, instead of topk/tree.
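
Below is a minimal, framework-free sketch of the expansion idea (hypothetical helper name, not part of the vLLM API): one sequence with k proposal tokens becomes k+1 target sequences, each scoring one more proposed token than the last.

from typing import List

def expand_one_sequence(context: List[int],
                        proposals: List[int]) -> List[List[int]]:
    """Return k+1 target sequences: the context plus every proposal prefix."""
    expanded = [context[:]]  # empty continuation; used to sample the bonus token
    for i in range(len(proposals)):
        expanded.append(context + proposals[:i + 1])
    return expanded

# Example: context [10, 11] with proposals [5, 6, 7] yields 4 target sequences.
print(expand_one_sequence([10, 11], [5, 6, 7]))
# [[10, 11], [10, 11, 5], [10, 11, 5, 6], [10, 11, 5, 6, 7]]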

Source code in vllm/spec_decode/batch_expansion.py
class BatchExpansionTop1Scorer(SpeculativeScorer):
    """Implements a speculative scorer that uses batch expansion to get
    probabilities of speculative tokens according to the scoring model.

    Batch expansion converts a list of sequences and multiple query positions
    to a new batch of sequences, each with a single query position. This allows
    for MQA-like scoring in speculative decoding without requiring an MQA
    kernel.

    It is strictly less efficient than MQA scoring.

    It only supports scoring the top1 proposal tokens of the proposer, instead
    of topk/tree.
    """

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens via the scorer model.

        This converts each input sequence to a set of k+1 target sequences. The
        target sequences have the unique continuations to be scored and a
        unique sequence ID that is different from all input sequence ids.

        If a speculative sequence length would exceed the max model length, then
        no speculation is produced for that sequence.

        Args:
            execute_model_req: The execution request.
            proposals: The speculative proposals to score.
        Returns:
            SpeculativeScores: The scores of each speculative token, along with
                which sequences were ignored during scoring.
        """

        # TODO(cade) perform this on GPU to remove blocking call.
        proposal_lens_list = proposals.proposal_lens.tolist()
        proposal_token_ids_list = proposals.proposal_token_ids.tolist()

        # Filter the list to ignore invalid proposals.
        proposal_token_ids_list_without_skips = [
            proposals for proposals in proposal_token_ids_list
            if VLLM_INVALID_TOKEN_ID not in proposals
        ]

        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
         num_scoring_tokens) = self._expand_batch(
             seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
             proposal_token_ids_list=proposal_token_ids_list_without_skips,
             proposal_lens_list=proposal_lens_list,
         )

        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=target_seq_group_metadata_list))
        assert len(target_sampler_output) == 1, "expected single-step output"
        target_sampler_output = target_sampler_output[0]

        if not non_spec_indices:
            # All sequence groups in batch have spec decoding enabled
            return self._contract_batch_all_spec(
                target_sampler_output=target_sampler_output,
                proposals=proposals,
            )
        else:
            # Batch has a mix of spec decode enabled and disabled seq groups
            return self._contract_batch(
                execute_model_req.seq_group_metadata_list,
                target_sampler_output=target_sampler_output,
                proposals=proposals,
                num_scoring_tokens=num_scoring_tokens,
                non_spec_indices=non_spec_indices,
                spec_indices=spec_indices,
                k=execute_model_req.num_lookahead_slots,
            )

    def _expand_batch(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids_list: List[List[TokenId]],
        proposal_lens_list: List[int],
    ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
        """Given the input sequences and potentially multiple corresponding
        proposal tokens, create a new batch where each sequence has a single
        query token.
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
            split_batch_by_proposal_len(
                seq_group_metadata_list, proposal_lens_list)

        spec_expanded_seqs = self._create_scoring_model_input(
            seq_group_metadata_list=spec_seqs,
            proposal_token_ids=proposal_token_ids_list,
            # NOTE: We determine the seq ids in the expanded batch using the
            # full seq_group_metadata_list, instead of only spec_seqs.
            target_seq_ids_iter=self._create_target_seq_id_iterator(
                seq_ids=get_all_seq_ids(seq_group_metadata_list)),
        )

        num_scoring_tokens = len(spec_expanded_seqs)
        # Batch speculative and non-speculative (e.g. chunked prefill) requests
        # but make sure order is prefill|decode due to backend requirement.
        target_seq_group_metadata_list = non_spec_seqs + spec_expanded_seqs

        return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
                num_scoring_tokens)

    def _contract_non_speculative(
            self, scores: SpeculativeScores,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
            has_prompt_log: bool) -> SpeculativeScores:
        """
            Augment the input `scores` with the outputs of non-speculative
            requests. This includes decode requests with speculation turned
            off, as well as prefill requests when `enable_chunked_prefill`
            is set. For the latter, prefills are further separated into
            terminal and non-terminal chunks (from which no token is sampled).
        """
        if not non_spec_indices:
            return scores

        if has_prompt_log:
            # When prompt_logprobs is enabled, prefills yield output token
            # (and respective prob) in the last entry (prompt|out):
            # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
            # With chunked prefill, non-terminal chunks have -1 on each
            # position: they're still picked, but they're discarded later.
            seq_meta = seq_group_metadata_list
            nospec_sizes = torch.tensor([
                seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
                for i in non_spec_indices
            ])
            nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
        else:
            # In this case only sampled tokens are returned, select all.
            nospec_sampled_token_idxs = list(
                range(len(non_spec_outputs.token_ids)))

        scores.token_ids[non_spec_indices, :1] = \
            non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
        scores.probs[non_spec_indices, :1, :] = \
            non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
        scores.logprobs[non_spec_indices, :1, :] = \
            non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
        if scores.hidden_states is not None:
            assert non_spec_outputs.hidden_states is not None
            scores.hidden_states[non_spec_indices, :1, :] = \
                non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1)
        return scores

    def _contract_batch(
            self,
            contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
            target_sampler_output: SamplerOutput,
            proposals: SpeculativeProposals, num_scoring_tokens: int,
            non_spec_indices: List[int], spec_indices: List[int],
            k: int) -> SpeculativeScores:
        """Contract the expanded batch back into its original size.
        This maps the scores of speculative tokens back to their original
        sequences.

        contracted_bs is the original batch size, and the batch size that the
        target_sampler_output will be contracted to.
        """
        contracted_bs = len(contracted_seq_group_metadata_list)
        (target_token_ids, target_probs, target_logprobs, target_hidden_states,
         non_spec_target_token_ids, non_spec_target_probs,
         non_spec_target_logprobs,
         non_spec_target_hidden_states) = self._split_scoring_output(
             target_sampler_output, num_scoring_tokens)

        # Map distinct sequences used to score each token
        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
        expanded_batch_size, k = proposals.proposal_token_ids.shape

        # The number of tokens in the expanded batch used for speculation is
        # equal to the total expanded batch size minus the number of samples for
        # non-speculative sequences, prefill chunks with no out tokens included
        non_spec_expanded_bs = len(non_spec_indices)
        spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

        target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
        target_probs = target_probs.reshape(*target_token_ids.shape,
                                            self._vocab_size)
        target_logprobs = target_logprobs.reshape(target_probs.shape)

        if target_hidden_states is not None:
            target_hidden_states = target_hidden_states.reshape(
                *target_token_ids.shape, target_hidden_states.shape[-1])

        all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                               fill_value=-1)
        all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                fill_value=-float("inf"))

        if target_sampler_output.hidden_states is not None:
            all_hidden_states = target_hidden_states.new_zeros(
                size=(contracted_bs, k + 1, target_hidden_states.shape[-1]))
        else:
            all_hidden_states = None

        has_prompt_log = any((sg.sampling_params.prompt_logprobs
                              and sg.sampling_params.prompt_logprobs > 0)
                             for sg in contracted_seq_group_metadata_list)
        # When prompt logprobs is enabled, lens of returned tensors go from
        # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
        # We adjust stride accordingly to get the generated tokens and
        # their probs, but pass on prompt_logprobs as is.
        prompt_logprobs = None
        if (not self._scorer_worker.model_runner.disable_logprobs\
            and has_prompt_log):
            prompt_logprobs = [
                o.prompt_logprobs for o in target_sampler_output.outputs
            ]
        elif not has_prompt_log:
            # When prompt logprobs are not to be returned,
            # we can ignore non-terminal chunks (no out token).
            non_spec_indices = [
                idx for idx in non_spec_indices
                if contracted_seq_group_metadata_list[idx].do_sample
            ]

        # "Contract" speculative.
        if spec_indices:
            all_tokens[spec_indices] = target_token_ids
            all_probs[spec_indices] = target_probs
            all_logprobs[spec_indices] = target_logprobs
            if all_hidden_states is not None:
                all_hidden_states[spec_indices] = target_hidden_states

        spec_scores = SpeculativeScores(probs=all_probs,
                                        token_ids=all_tokens,
                                        logprobs=all_logprobs,
                                        hidden_states=all_hidden_states,
                                        prompt_logprobs=prompt_logprobs)

        non_spec_outputs = SpeculativeScores(
            probs=non_spec_target_probs,
            token_ids=non_spec_target_token_ids,
            logprobs=non_spec_target_logprobs,
            hidden_states=non_spec_target_hidden_states)
        # Contract remaining nonspec entries based on non_spec_indices, if any.
        return self._contract_non_speculative(
            spec_scores, contracted_seq_group_metadata_list, non_spec_indices,
            non_spec_outputs, has_prompt_log)

    def _contract_batch_all_spec(
        self,
        target_sampler_output: SamplerOutput,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Contract the expanded batch back into its original size.
        This maps the scores of speculative tokens back to their original
        sequences.

        It assumes all sequences in the batch were previously expanded.
        """

        # Map distinct sequences used to score each token
        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
        contracted_bs, k = proposals.proposal_token_ids.shape

        # Reshape tensors to original batch size
        target_token_ids = target_sampler_output.sampled_token_ids.reshape(
            contracted_bs, k + 1)
        target_probs = target_sampler_output.sampled_token_probs.reshape(
            *target_token_ids.shape, self._vocab_size)
        target_logprobs = target_sampler_output.logprobs.reshape(
            target_probs.shape)
        target_hidden_states = target_sampler_output.hidden_states
        if target_hidden_states is not None:
            target_hidden_states = target_hidden_states.reshape(
                *target_token_ids.shape, target_hidden_states.shape[-1])

        return SpeculativeScores(probs=target_probs,
                                 token_ids=target_token_ids,
                                 logprobs=target_logprobs,
                                 hidden_states=target_hidden_states,
                                 prompt_logprobs=None)

    def _create_scoring_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
        target_seq_ids_iter: Iterator[TargetSeqId],
    ) -> List[SequenceGroupMetadata]:
        """Given the original input sequences and proposed tokens from the draft
        model, create a list of target sequences that can be used for scoring.

        target_seq_ids_iter provides sequence ids for the expanded batch,
        fulfilling the requirement that no seq id in the expanded batch is equal
        to the seq id in the original batch.
        """

        if not seq_group_metadata_list:
            return []

        target_seq_group_metadata = list(
            chain.from_iterable(
                self._create_target_seq_group_metadata(
                    seq_group_metadata,
                    proposal_token_ids,
                    i,
                    target_seq_ids_iter,
                ) for i, seq_group_metadata in enumerate(
                    seq_group_metadata_list)))

        return target_seq_group_metadata

    def _create_target_seq_group_metadata(
        self,
        input_seq_group_metadata: SequenceGroupMetadata,
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
        batch_index: int,
        target_seq_ids_iter: Iterator[TargetSeqId],
    ) -> List[SequenceGroupMetadata]:
        """Given an input sequence group metadata and a list of draft tokens,
        create a list of target SequenceGroupMetadata, one for each
        token id that needs to be scored.

        Naive speculative decoding requires K target model scores, one for each
        draft model token. However one can add a bonus token such that if each
        token is accepted, then a final token may be sampled from the model.
        This function creates K+1 target SequenceGroupMetadata to take
        advantage of the bonus token.
        """
        assert len(input_seq_group_metadata.seq_data) == 1, (
            "Beam search "
            "not supported in speculative decoding")
        input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))

        token_ids_to_score = self._get_token_ids_to_score(
            proposal_token_ids[batch_index])

        sampling_params = input_seq_group_metadata.sampling_params
        target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for i, token_ids in enumerate(token_ids_to_score):
            target_seq_group_metadata_list.append(
                self._create_single_target_seq_group_metadata(
                    input_seq_group_metadata,
                    input_seq_id,
                    next(target_seq_ids_iter),
                    token_ids,
                    sampling_params=sampling_params,
                ))

        return target_seq_group_metadata_list

    @staticmethod
    def _create_single_target_seq_group_metadata(
        seq_group_metadata: SequenceGroupMetadata,
        seq_id: SeqId,
        target_seq_id: TargetSeqId,
        token_ids: List[TokenId],
        sampling_params: SamplingParams,
    ) -> SequenceGroupMetadata:
        """Create a single target SequenceGroupMetadata.

        Args:
            seq_group_metadata: The metadata for the input sequence.
            seq_id: The input sequence ID.
            target_seq_id: The corresponding target sequence ID.
            token_ids: The list of token ids that are to be appended to the
                input sequence.
        """
        seq_data = seq_group_metadata.seq_data[seq_id]
        prompt_token_ids = seq_data.prompt_token_ids_array
        new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
        mrope_position_delta = seq_data.mrope_position_delta

        new_seq_data_dict = {
            target_seq_id:
            SequenceData(
                prompt_token_ids,
                _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                        new_output_token_ids),
            ),
        }
        # This is a hack. Technically, spec decoding should compute
        # num_lookahead_slots in one shot, but instead it currently expands
        # the batch and evaluates tokens one by one. context_len is
        # seq_len - 1 because the kv cache is filled by a previous batch in
        # the batch expansion.
        for data in new_seq_data_dict.values():
            data.update_num_computed_tokens(data.get_len() - 1)
            data.mrope_position_delta = mrope_position_delta

        return SequenceGroupMetadata(
            request_id=seq_group_metadata.request_id,
            is_prompt=seq_group_metadata.is_prompt,
            seq_data=new_seq_data_dict,
            sampling_params=sampling_params,
            block_tables={
                target_seq_id: seq_group_metadata.block_tables[seq_id],
            },
            lora_request=None,
            token_chunk_size=1,
        )

    @staticmethod
    def _split_scoring_output(
        sampler_output: SamplerOutput, num_scoring_tokens: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor], torch.Tensor, torch.Tensor,
               torch.Tensor, Optional[torch.Tensor]]:
        """Split the target model output into speculative and non-speculative
        output.
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        #
        # First samples are non-speculative, latter samples are from speculative
        # scoring (prefill|decode order).
        split_sizes = (sampler_output.sampled_token_ids.numel() -
                       num_scoring_tokens, num_scoring_tokens)
        (non_spec_probs,
         spec_probs) = sampler_output.sampled_token_probs.split(split_sizes)
        (non_spec_sampled_tokens, spec_sampled_tokens
         ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
        (non_spec_logprobs,
         spec_logprobs) = sampler_output.logprobs.split(split_sizes)

        if sampler_output.hidden_states is not None:
            (non_spec_hidden_states, spec_hidden_states
             ) = sampler_output.hidden_states.split(split_sizes)
        else:
            non_spec_hidden_states, spec_hidden_states = None, None

        return (spec_sampled_tokens, spec_probs, spec_logprobs,
                spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
                non_spec_logprobs, non_spec_hidden_states)

    @staticmethod
    def _create_target_seq_id_iterator(
            seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
        """Create an iterator for creating target sequence ids.
        Target sequence ids are distinct from sequence ids because we create a
        distinct target sequence id for each proposal token to be scored.

        This implementation increments a counter starting at 1 + max of all
        provided input sequence ids.
        """
        return count(start=max(seq_ids) + 1)

    @staticmethod
    def _get_token_ids_to_score(
        full_spec_token_ids: List[TokenId]  # shape: [k]
    ) -> List[List[TokenId]]:
        """Given an int tensor of proposal token ids, return a list of
        token ids that should be scored.

        Returns k+1 output lists. The additional one is used for generating the
        bonus token.

        Example:
            Input: [0, 1, 2, 3] (k=4)
            Output: (k+1 lists)
                []
                [0]
                [0, 1]
                [0, 1, 2]
                [0, 1, 2, 3]
        """
        empty_token_ids: List[TokenId] = []

        token_ids_to_score = [empty_token_ids]
        token_ids_to_score.extend(full_spec_token_ids[:i + 1]
                                  for i in range(len(full_spec_token_ids)))
        return token_ids_to_score

_contract_batch

_contract_batch(
    contracted_seq_group_metadata_list: List[
        SequenceGroupMetadata
    ],
    target_sampler_output: SamplerOutput,
    proposals: SpeculativeProposals,
    num_scoring_tokens: int,
    non_spec_indices: List[int],
    spec_indices: List[int],
    k: int,
) -> SpeculativeScores

Contract the expanded batch back into its original size. This maps the scores of speculative tokens back to their original sequences.

contracted_bs is the original batch size, and the batch size that the target_sampler_output will be contracted to.
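
The following toy illustration (dummy tensors, assumed sizes, plain PyTorch) shows the core of the contraction step: the expanded speculative scores are scattered back into a tensor sized for the original batch, and rows that were not speculatively scored keep the fill value -1 until _contract_non_speculative fills them in.

import torch

contracted_bs, k = 4, 3          # original batch size, proposal length
spec_indices = [0, 2, 3]         # sequences that used speculation
spec_bs = len(spec_indices)

# Stand-in for the reshaped output of the expanded scoring pass.
target_token_ids = torch.arange(spec_bs * (k + 1)).reshape(spec_bs, k + 1)

all_tokens = target_token_ids.new_full((contracted_bs, k + 1), fill_value=-1)
all_tokens[spec_indices] = target_token_ids
print(all_tokens[1])  # tensor([-1, -1, -1, -1]) -- the non-speculative row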

Source code in vllm/spec_decode/batch_expansion.py
def _contract_batch(
        self,
        contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
        target_sampler_output: SamplerOutput,
        proposals: SpeculativeProposals, num_scoring_tokens: int,
        non_spec_indices: List[int], spec_indices: List[int],
        k: int) -> SpeculativeScores:
    """Contract the expanded batch back into its original size.
    This maps the scores of speculative tokens back to their original
    sequences.

    contracted_bs is the original batch size, and the batch size that the
    target_sampler_output will be contracted to.
    """
    contracted_bs = len(contracted_seq_group_metadata_list)
    (target_token_ids, target_probs, target_logprobs, target_hidden_states,
     non_spec_target_token_ids, non_spec_target_probs,
     non_spec_target_logprobs,
     non_spec_target_hidden_states) = self._split_scoring_output(
         target_sampler_output, num_scoring_tokens)

    # Map distinct sequences used to score each token
    # of shape [batch_size * k + 1] back to [batch_size, k + 1].
    expanded_batch_size, k = proposals.proposal_token_ids.shape

    # The number of tokens in the expanded batch used for speculation is
    # equal to the total expanded batch size minus the number of samples for
    # non-speculative sequences, prefill chunks with no out tokens included
    non_spec_expanded_bs = len(non_spec_indices)
    spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs

    target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
    target_probs = target_probs.reshape(*target_token_ids.shape,
                                        self._vocab_size)
    target_logprobs = target_logprobs.reshape(target_probs.shape)

    if target_hidden_states is not None:
        target_hidden_states = target_hidden_states.reshape(
            *target_token_ids.shape, target_hidden_states.shape[-1])

    all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                           fill_value=-1)
    all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
    all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                            fill_value=-float("inf"))

    if target_sampler_output.hidden_states is not None:
        all_hidden_states = target_hidden_states.new_zeros(
            size=(contracted_bs, k + 1, target_hidden_states.shape[-1]))
    else:
        all_hidden_states = None

    has_prompt_log = any((sg.sampling_params.prompt_logprobs
                          and sg.sampling_params.prompt_logprobs > 0)
                         for sg in contracted_seq_group_metadata_list)
    # When prompt logprobs is enabled, lens of returned tensors go from
    # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
    # We adjust stride accordingly to get the generated tokens and
    # their probs, but pass on prompt_logprobs as is.
    prompt_logprobs = None
    if (not self._scorer_worker.model_runner.disable_logprobs\
        and has_prompt_log):
        prompt_logprobs = [
            o.prompt_logprobs for o in target_sampler_output.outputs
        ]
    elif not has_prompt_log:
        # When prompt logprobs are not to be returned,
        # we can ignore non-terminal chunks (no out token).
        non_spec_indices = [
            idx for idx in non_spec_indices
            if contracted_seq_group_metadata_list[idx].do_sample
        ]

    # "Contract" speculative.
    if spec_indices:
        all_tokens[spec_indices] = target_token_ids
        all_probs[spec_indices] = target_probs
        all_logprobs[spec_indices] = target_logprobs
        if all_hidden_states is not None:
            all_hidden_states[spec_indices] = target_hidden_states

    spec_scores = SpeculativeScores(probs=all_probs,
                                    token_ids=all_tokens,
                                    logprobs=all_logprobs,
                                    hidden_states=all_hidden_states,
                                    prompt_logprobs=prompt_logprobs)

    non_spec_outputs = SpeculativeScores(
        probs=non_spec_target_probs,
        token_ids=non_spec_target_token_ids,
        logprobs=non_spec_target_logprobs,
        hidden_states=non_spec_target_hidden_states)
    # Contract remaining nonspec entries based on non_spec_indices, if any.
    return self._contract_non_speculative(
        spec_scores, contracted_seq_group_metadata_list, non_spec_indices,
        non_spec_outputs, has_prompt_log)

_contract_batch_all_spec

_contract_batch_all_spec(
    target_sampler_output: SamplerOutput,
    proposals: SpeculativeProposals,
) -> SpeculativeScores

Contract the expanded batch back into its original size. This maps the scores of speculative tokens back to their original sequences.

It assumes all sequences in the batch were previously expanded.
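
A small dummy-tensor sketch of the all-speculative case: the flat sampler output of length bs * (k + 1) is simply viewed as [bs, k + 1], with no scattering required.

import torch

bs, k = 2, 3
sampled_token_ids = torch.arange(bs * (k + 1))   # stand-in for the sampler output
target_token_ids = sampled_token_ids.reshape(bs, k + 1)
print(target_token_ids.shape)  # torch.Size([2, 4])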

Source code in vllm/spec_decode/batch_expansion.py
def _contract_batch_all_spec(
    self,
    target_sampler_output: SamplerOutput,
    proposals: SpeculativeProposals,
) -> SpeculativeScores:
    """Contract the expanded batch back into its original size.
    This maps the scores of speculative tokens back to their original
    sequences.

    It assumes all sequences in the batch were previously expanded.
    """

    # Map distinct sequences used to score each token
    # of shape [batch_size * k + 1] back to [batch_size, k + 1].
    contracted_bs, k = proposals.proposal_token_ids.shape

    # Reshape tensors to original batch size
    target_token_ids = target_sampler_output.sampled_token_ids.reshape(
        contracted_bs, k + 1)
    target_probs = target_sampler_output.sampled_token_probs.reshape(
        *target_token_ids.shape, self._vocab_size)
    target_logprobs = target_sampler_output.logprobs.reshape(
        target_probs.shape)
    target_hidden_states = target_sampler_output.hidden_states
    if target_hidden_states is not None:
        target_hidden_states = target_hidden_states.reshape(
            *target_token_ids.shape, target_hidden_states.shape[-1])

    return SpeculativeScores(probs=target_probs,
                             token_ids=target_token_ids,
                             logprobs=target_logprobs,
                             hidden_states=target_hidden_states,
                             prompt_logprobs=None)

_contract_non_speculative

_contract_non_speculative(
    scores: SpeculativeScores,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    non_spec_indices: List[int],
    non_spec_outputs: SpeculativeScores,
    has_prompt_log: bool,
) -> SpeculativeScores

Augment the input scores with the outputs of non-speculative requests. This includes decode requests with speculation turned off, as well as prefill requests when enable_chunked_prefill is set. For the latter, prefills are further separated into terminal and non-terminal chunks (from which no token is sampled).
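
A short sketch of the index arithmetic used when prompt logprobs are returned: each prefill contributes token_chunk_size entries and each decode one entry, so the sampled token of every non-speculative request sits at cumsum - 1 (example sizes are assumed).

import torch

nospec_sizes = torch.tensor([3, 2, 1])  # e.g. two prefill chunks, one decode
nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
print(nospec_sampled_token_idxs)  # tensor([2, 4, 5])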

Source code in vllm/spec_decode/batch_expansion.py
def _contract_non_speculative(
        self, scores: SpeculativeScores,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
        has_prompt_log: bool) -> SpeculativeScores:
    """
        Augment the input `scores` with the outputs of non-speculative
        requests. This includes decode requests with speculation turned
        off, as well as prefill requests when `enable_chunked_prefill`
        is set. For the latter, prefills are further separated into
        terminal and non-terminal chunks (from which no token is sampled).
    """
    if not non_spec_indices:
        return scores

    if has_prompt_log:
        # When prompt_logprobs is enabled, prefills yield output token
        # (and respective prob) in the last entry (prompt|out):
        # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
        # With chunked prefill, non-terminal chunks have -1 on each
        # position: they're still picked, but they're discarded later.
        seq_meta = seq_group_metadata_list
        nospec_sizes = torch.tensor([
            seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
            for i in non_spec_indices
        ])
        nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
    else:
        # In this case only sampled tokens are returned, select all.
        nospec_sampled_token_idxs = list(
            range(len(non_spec_outputs.token_ids)))

    scores.token_ids[non_spec_indices, :1] = \
        non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
    scores.probs[non_spec_indices, :1, :] = \
        non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
    scores.logprobs[non_spec_indices, :1, :] = \
        non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
    if scores.hidden_states is not None:
        assert non_spec_outputs.hidden_states is not None
        scores.hidden_states[non_spec_indices, :1, :] = \
            non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1)
    return scores

_create_scoring_model_input

_create_scoring_model_input(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_token_ids: List[List[TokenId]],
    target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]

Given the original input sequences and proposed tokens from the draft model, create a list of target sequences that can be used for scoring.

target_seq_ids_iter provides sequence ids for the expanded batch, fulfilling the requirement that no seq id in the expanded batch is equal to the seq id in the original batch.
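
The flattening pattern used here, shown in isolation: every input sequence group expands into several target entries, and chain.from_iterable concatenates them into one flat list while preserving per-sequence order (labels below are illustrative only).

from itertools import chain

per_sequence_targets = [["s0_t0", "s0_t1"], ["s1_t0", "s1_t1", "s1_t2"]]
flat = list(chain.from_iterable(per_sequence_targets))
print(flat)  # ['s0_t0', 's0_t1', 's1_t0', 's1_t1', 's1_t2']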

Source code in vllm/spec_decode/batch_expansion.py
def _create_scoring_model_input(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
    target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]:
    """Given the original input sequences and proposed tokens from the draft
    model, create a list of target sequences that can be used for scoring.

    target_seq_ids_iter provides sequence ids for the expanded batch,
    fulfilling the requirement that no seq id in the expanded batch is equal
    to the seq id in the original batch.
    """

    if not seq_group_metadata_list:
        return []

    target_seq_group_metadata = list(
        chain.from_iterable(
            self._create_target_seq_group_metadata(
                seq_group_metadata,
                proposal_token_ids,
                i,
                target_seq_ids_iter,
            ) for i, seq_group_metadata in enumerate(
                seq_group_metadata_list)))

    return target_seq_group_metadata

_create_single_target_seq_group_metadata staticmethod

_create_single_target_seq_group_metadata(
    seq_group_metadata: SequenceGroupMetadata,
    seq_id: SeqId,
    target_seq_id: TargetSeqId,
    token_ids: List[TokenId],
    sampling_params: SamplingParams,
) -> SequenceGroupMetadata

Create a single target SequenceGroupMetadata.

Parameters:

    seq_group_metadata (SequenceGroupMetadata, required):
        The metadata for the input sequence.
    seq_id (SeqId, required):
        The input sequence ID.
    target_seq_id (TargetSeqId, required):
        The corresponding target sequence ID.
    token_ids (List[TokenId], required):
        The list of token ids that are to be appended to the input sequence.
Source code in vllm/spec_decode/batch_expansion.py
@staticmethod
def _create_single_target_seq_group_metadata(
    seq_group_metadata: SequenceGroupMetadata,
    seq_id: SeqId,
    target_seq_id: TargetSeqId,
    token_ids: List[TokenId],
    sampling_params: SamplingParams,
) -> SequenceGroupMetadata:
    """Create a single target SequenceGroupMetadata.

    Args:
        seq_group_metadata: The metadata for the input sequence.
        seq_id: The input sequence ID.
        target_seq_id: The corresponding target sequence ID.
        token_ids: The list of token ids that are to be appended to the
            input sequence.
    """
    seq_data = seq_group_metadata.seq_data[seq_id]
    prompt_token_ids = seq_data.prompt_token_ids_array
    new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
    mrope_position_delta = seq_data.mrope_position_delta

    new_seq_data_dict = {
        target_seq_id:
        SequenceData(
            prompt_token_ids,
            _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                    new_output_token_ids),
        ),
    }
    # This is a hack. Technically, spec decoding should compute
    # num_lookahead_slots in one shot, but instead it currently expands
    # the batch and evaluates tokens one by one. context_len is
    # seq_len - 1 because the kv cache is filled by a previous batch in
    # the batch expansion.
    for data in new_seq_data_dict.values():
        data.update_num_computed_tokens(data.get_len() - 1)
        data.mrope_position_delta = mrope_position_delta

    return SequenceGroupMetadata(
        request_id=seq_group_metadata.request_id,
        is_prompt=seq_group_metadata.is_prompt,
        seq_data=new_seq_data_dict,
        sampling_params=sampling_params,
        block_tables={
            target_seq_id: seq_group_metadata.block_tables[seq_id],
        },
        lora_request=None,
        token_chunk_size=1,
    )

_create_target_seq_group_metadata

_create_target_seq_group_metadata(
    input_seq_group_metadata: SequenceGroupMetadata,
    proposal_token_ids: List[List[TokenId]],
    batch_index: int,
    target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]

Given an input sequence group metadata and a list of draft tokens, create a list of target SequenceGroupMetadata, one for each token id that needs to be scored.

Naive speculative decoding requires K target model scores, one for each draft model token. However one can add a bonus token such that if each token is accepted, then a final token may be sampled from the model. This function creates K+1 target SequenceGroupMetadata to take advantage of the bonus token.

Source code in vllm/spec_decode/batch_expansion.py
def _create_target_seq_group_metadata(
    self,
    input_seq_group_metadata: SequenceGroupMetadata,
    proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
    batch_index: int,
    target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]:
    """Given an input sequence group metadata and a list of draft tokens,
    create a list of target SequenceGroupMetadata, one for each
    token id that needs to be scored.

    Naive speculative decoding requires K target model scores, one for each
    draft model token. However one can add a bonus token such that if each
    token is accepted, then a final token may be sampled from the model.
    This function creates K+1 target SequenceGroupMetadata to take
    advantage of the bonus token.
    """
    assert len(input_seq_group_metadata.seq_data) == 1, (
        "Beam search "
        "not supported in speculative decoding")
    input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))

    token_ids_to_score = self._get_token_ids_to_score(
        proposal_token_ids[batch_index])

    sampling_params = input_seq_group_metadata.sampling_params
    target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
    for i, token_ids in enumerate(token_ids_to_score):
        target_seq_group_metadata_list.append(
            self._create_single_target_seq_group_metadata(
                input_seq_group_metadata,
                input_seq_id,
                next(target_seq_ids_iter),
                token_ids,
                sampling_params=sampling_params,
            ))

    return target_seq_group_metadata_list

_create_target_seq_id_iterator staticmethod

_create_target_seq_id_iterator(
    seq_ids: List[SeqId],
) -> Iterator[TargetSeqId]

Create an iterator for creating target sequence ids. Target sequence ids are distinct from sequence ids because we create a distinct target sequence id for each proposal token to be scored.

This implementation increments a counter starting at 1 + max of all provided input sequence ids.
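
The same counter in isolation: fresh target ids start just above the largest existing sequence id, so expanded sequences never collide with the inputs.

from itertools import count

seq_ids = [0, 3, 7]
target_seq_ids = count(start=max(seq_ids) + 1)
print([next(target_seq_ids) for _ in range(4)])  # [8, 9, 10, 11]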

Source code in vllm/spec_decode/batch_expansion.py
@staticmethod
def _create_target_seq_id_iterator(
        seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
    """Create an iterator for creating target sequence ids.
    Target sequence ids are distinct from sequence ids because we create a
    distinct target sequence id for each proposal token to be scored.

    This implementation increments a counter starting at 1 + max of all
    provided input sequence ids.
    """
    return count(start=max(seq_ids) + 1)

_expand_batch

_expand_batch(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_token_ids_list: List[List[TokenId]],
    proposal_lens_list: List[int],
) -> Tuple[
    List[int], List[int], List[SequenceGroupMetadata], int
]

Given the input sequences and potentially multiple corresponding proposal tokens, create a new batch where each sequence has a single query token.
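
A minimal sketch (hypothetical helper, not the signature of vLLM's split_batch_by_proposal_len) of the split that precedes expansion: sequences with a zero proposal length are scored without expansion, while the rest are expanded.

def split_by_proposal_len(items, proposal_lens):
    """Return (spec, non_spec) as lists of (original_index, item) pairs."""
    spec = [(i, x) for i, (x, n) in enumerate(zip(items, proposal_lens)) if n > 0]
    non_spec = [(i, x) for i, (x, n) in enumerate(zip(items, proposal_lens)) if n == 0]
    return spec, non_spec

spec, non_spec = split_by_proposal_len(["seqA", "seqB", "seqC"], [3, 0, 3])
print(spec)      # [(0, 'seqA'), (2, 'seqC')]
print(non_spec)  # [(1, 'seqB')]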

Source code in vllm/spec_decode/batch_expansion.py
def _expand_batch(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_token_ids_list: List[List[TokenId]],
    proposal_lens_list: List[int],
) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
    """Given the input sequences and potentially multiple corresponding
    proposal tokens, create a new batch where each sequence has a single
    query token.
    """

    # vLLM currently only supports proposal lens equal to zero or the batch
    # proposal len. This adds some complexity (splitting the batch into spec
    # and non spec sequences) and should be removed in the future. It can be
    # done by supporting per-sequence proposal lens.
    (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
        split_batch_by_proposal_len(
            seq_group_metadata_list, proposal_lens_list)

    spec_expanded_seqs = self._create_scoring_model_input(
        seq_group_metadata_list=spec_seqs,
        proposal_token_ids=proposal_token_ids_list,
        # NOTE: We determine the seq ids in the expanded batch using the
        # full seq_group_metadata_list, instead of only spec_seqs.
        target_seq_ids_iter=self._create_target_seq_id_iterator(
            seq_ids=get_all_seq_ids(seq_group_metadata_list)),
    )

    num_scoring_tokens = len(spec_expanded_seqs)
    # Batch speculative and non-speculative (e.g. chunked prefill) requests
    # but make sure order is prefill|decode due to backend requirement.
    target_seq_group_metadata_list = non_spec_seqs + spec_expanded_seqs

    return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
            num_scoring_tokens)

_get_token_ids_to_score staticmethod

_get_token_ids_to_score(
    full_spec_token_ids: List[TokenId],
) -> List[List[TokenId]]

Given an int tensor of proposal token ids, return a list of token ids that should be scored.

Returns k+1 output lists. The additional one is used for generating the bonus token.

Example:

    Input: [0, 1, 2, 3] (k=4)
    Output (k+1 lists):
        []
        [0]
        [0, 1]
        [0, 1, 2]
        [0, 1, 2, 3]
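
As a quick, runnable check of the example above (plain Python, no vLLM imports), the prefix construction can be reproduced standalone:

full_spec_token_ids = [0, 1, 2, 3]
token_ids_to_score = [[]]
token_ids_to_score.extend(full_spec_token_ids[:i + 1]
                          for i in range(len(full_spec_token_ids)))
print(token_ids_to_score)  # [[], [0], [0, 1], [0, 1, 2], [0, 1, 2, 3]]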

Source code in vllm/spec_decode/batch_expansion.py
@staticmethod
def _get_token_ids_to_score(
    full_spec_token_ids: List[TokenId]  # shape: [k]
) -> List[List[TokenId]]:
    """Given an int tensor of proposal token ids, return a list of
    token ids that should be scored.

    Returns k+1 output lists. The additional one is used for generating the
    bonus token.

    Example:
        Input: [0, 1, 2, 3] (k=4)
        Output: (k+1 lists)
            []
            [0]
            [0, 1]
            [0, 1, 2]
            [0, 1, 2, 3]
    """
    empty_token_ids: List[TokenId] = []

    token_ids_to_score = [empty_token_ids]
    token_ids_to_score.extend(full_spec_token_ids[:i + 1]
                              for i in range(len(full_spec_token_ids)))
    return token_ids_to_score

_split_scoring_output staticmethod

_split_scoring_output(
    sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[
    Tensor,
    Tensor,
    Tensor,
    Optional[Tensor],
    Tensor,
    Tensor,
    Tensor,
    Optional[Tensor],
]

Split the target model output into speculative and non-speculative output.
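
A dummy-tensor illustration of the split (assumed sizes): the first rows of the flat sampler output belong to non-speculative requests, and the last num_scoring_tokens rows belong to the expanded speculative sequences.

import torch

total_samples, num_scoring_tokens, vocab = 7, 4, 10
sampled_token_probs = torch.rand(total_samples, vocab)

split_sizes = (total_samples - num_scoring_tokens, num_scoring_tokens)
non_spec_probs, spec_probs = sampled_token_probs.split(split_sizes)
print(non_spec_probs.shape, spec_probs.shape)
# torch.Size([3, 10]) torch.Size([4, 10])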

Source code in vllm/spec_decode/batch_expansion.py
@staticmethod
def _split_scoring_output(
    sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
           Optional[torch.Tensor], torch.Tensor, torch.Tensor,
           torch.Tensor, Optional[torch.Tensor]]:
    """Split the target model output into speculative and non-speculative
    output.
    """

    # vLLM currently only supports proposal lens equal to zero or the batch
    # proposal len. This adds some complexity (splitting the batch into spec
    # and non spec sequences) and should be removed in the future. It can be
    # done by supporting per-sequence proposal lens.
    #
    # First samples are non-speculative, latter samples are from speculative
    # scoring (prefill|decode order).
    split_sizes = (sampler_output.sampled_token_ids.numel() -
                   num_scoring_tokens, num_scoring_tokens)
    (non_spec_probs,
     spec_probs) = sampler_output.sampled_token_probs.split(split_sizes)
    (non_spec_sampled_tokens, spec_sampled_tokens
     ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
    (non_spec_logprobs,
     spec_logprobs) = sampler_output.logprobs.split(split_sizes)

    if sampler_output.hidden_states is not None:
        (non_spec_hidden_states, spec_hidden_states
         ) = sampler_output.hidden_states.split(split_sizes)
    else:
        non_spec_hidden_states, spec_hidden_states = None, None

    return (spec_sampled_tokens, spec_probs, spec_logprobs,
            spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
            non_spec_logprobs, non_spec_hidden_states)

score_proposals

score_proposals(
    execute_model_req: ExecuteModelRequest,
    proposals: SpeculativeProposals,
) -> SpeculativeScores

Score the proposed tokens via the scorer model.

This converts each input sequence to a set of k+1 target sequences. The target sequences have the unique continuations to be scored and a unique sequence ID that is different from all input sequence ids.

If a speculative sequence length would exceed the max model length, then no speculation is produced for that sequence.

Parameters:

    execute_model_req (ExecuteModelRequest, required):
        The execution request.
    proposals (SpeculativeProposals, required):
        The speculative proposals to score.

Returns:

    SpeculativeScores: The scores of each speculative token, along with which
    sequences were ignored during scoring.
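
A sketch of the proposal-filtering step using a placeholder sentinel (the real constant is VLLM_INVALID_TOKEN_ID): any proposal row containing the invalid token id is dropped before the batch is expanded for scoring.

INVALID_TOKEN_ID = -1  # placeholder value for illustration only

proposal_token_ids_list = [[5, 6, 7], [INVALID_TOKEN_ID, 2, 3], [8, 9, 1]]
without_skips = [p for p in proposal_token_ids_list if INVALID_TOKEN_ID not in p]
print(without_skips)  # [[5, 6, 7], [8, 9, 1]]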

Source code in vllm/spec_decode/batch_expansion.py
@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
def score_proposals(
    self,
    execute_model_req: ExecuteModelRequest,
    proposals: SpeculativeProposals,
) -> SpeculativeScores:
    """Score the proposed tokens via the scorer model.

    This converts each input sequence to a set of k+1 target sequences. The
    target sequences have the unique continuations to be scored and a
    unique sequence ID that is different from all input sequence ids.

    If a speculative sequence length would exceed the max model length, then
    no speculation is produced for that sequence.

    Args:
        execute_model_req: The execution request.
        proposals: The speculative proposals to score.
    Returns:
        SpeculativeScores: The scores of each speculative token, along with
            which sequences were ignored during scoring.
    """

    # TODO(cade) perform this on GPU to remove blocking call.
    proposal_lens_list = proposals.proposal_lens.tolist()
    proposal_token_ids_list = proposals.proposal_token_ids.tolist()

    # Filter the list to ignore invalid proposals.
    proposal_token_ids_list_without_skips = [
        proposals for proposals in proposal_token_ids_list
        if VLLM_INVALID_TOKEN_ID not in proposals
    ]

    (spec_indices, non_spec_indices, target_seq_group_metadata_list,
     num_scoring_tokens) = self._expand_batch(
         seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
         proposal_token_ids_list=proposal_token_ids_list_without_skips,
         proposal_lens_list=proposal_lens_list,
     )

    target_sampler_output = self._scorer_worker.execute_model(
        execute_model_req=execute_model_req.clone(
            seq_group_metadata_list=target_seq_group_metadata_list))
    assert len(target_sampler_output) == 1, "expected single-step output"
    target_sampler_output = target_sampler_output[0]

    if not non_spec_indices:
        # All sequence groups in batch have spec decoding enabled
        return self._contract_batch_all_spec(
            target_sampler_output=target_sampler_output,
            proposals=proposals,
        )
    else:
        # Batch has a mix of spec decode enabled and disabled seq groups
        return self._contract_batch(
            execute_model_req.seq_group_metadata_list,
            target_sampler_output=target_sampler_output,
            proposals=proposals,
            num_scoring_tokens=num_scoring_tokens,
            non_spec_indices=non_spec_indices,
            spec_indices=spec_indices,
            k=execute_model_req.num_lookahead_slots,
        )