Skip to content

vllm.spec_decode.mqa_scorer

SeqId module-attribute

SeqId = int

TargetSeqId module-attribute

TargetSeqId = int

MQAScorer

Bases: SpeculativeScorer

Source code in vllm/spec_decode/mqa_scorer.py
class MQAScorer(SpeculativeScorer):

    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        target_seq_group_metadata_list = []
        target_seq_id_start = max(
            get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1
        all_proposal_tokens = proposals.proposal_token_ids.tolist()
        all_proposal_lengths = proposals.proposal_lens.tolist()
        for i, seq_group_metadata in enumerate(
                execute_model_req.seq_group_metadata_list):
            if all_proposal_lengths[i] == 0:
                # Keep prompt seqs untouched (keep computed_tokens for chunks).
                target_seq_group_metadata_list.append(seq_group_metadata)
                continue

            seq_data_dict = seq_group_metadata.seq_data
            assert len(seq_data_dict) == 1
            seq_id = next(iter(seq_data_dict.keys()))

            seq_data: SequenceData = seq_data_dict[seq_id]
            prompt_token_ids = seq_data.get_prompt_token_ids()
            output_token_ids = seq_data.get_output_token_ids()
            proposal_token_ids = all_proposal_tokens[
                i][:all_proposal_lengths[i]]
            new_output_token_ids = [*output_token_ids, *proposal_token_ids]

            target_seq_id = target_seq_id_start + i
            new_seq_data = SequenceData.from_seqs(
                prompt_token_ids=prompt_token_ids,
                output_token_ids=new_output_token_ids,
            )
            new_seq_data.update_num_computed_tokens(
                len(prompt_token_ids) + len(output_token_ids) - 1)

            # Ensure that the new decode sequence has at least one token.
            assert len(output_token_ids) >= 1
            new_seq_data_dict = {target_seq_id: new_seq_data}

            new_seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group_metadata.request_id,
                is_prompt=seq_group_metadata.is_prompt,
                seq_data=new_seq_data_dict,
                sampling_params=seq_group_metadata.sampling_params,
                block_tables={
                    target_seq_id: seq_group_metadata.block_tables[seq_id],
                },
                lora_request=None,
            )
            target_seq_group_metadata_list.append(new_seq_group_metadata)

        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=target_seq_group_metadata_list))

        target_sampler_output = target_sampler_output[0]

        k = execute_model_req.num_lookahead_slots
        bs = len(execute_model_req.seq_group_metadata_list)
        target_token_ids = target_sampler_output.sampled_token_ids
        target_probs = target_sampler_output.sampled_token_probs
        target_logprobs = target_sampler_output.logprobs
        prompt_logprobs = None

        # If all requests have the same number of query tokens, we can avoid
        # the for loop to build output for better performance.
        if min(all_proposal_lengths) == k:
            # Regular decodes only.
            assert all(not sg.is_prompt
                       for sg in target_seq_group_metadata_list
                       if sg.is_prompt)
            bs, _ = proposals.proposal_token_ids.shape
            all_tokens = target_token_ids.reshape(bs, k + 1)
            all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
            all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size)
        else:
            # We either have decodes with different lens or prefill+decodes.
            all_tokens = target_token_ids.new_full(size=(bs, k + 1),
                                                   fill_value=-1)
            all_probs = target_probs.new_zeros(*all_tokens.shape,
                                               self._vocab_size)
            all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                    fill_value=-float("inf"))
            target_token_ids = target_token_ids.flatten()

            # When prompt logprobs is enabled, lens of returned tensors go from
            # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
            # We adjust stride accordingly to get the generated tokens and
            # their probs, but pass on prompt_logprobs as is, since it may be
            # that n_prompts >> K.
            has_prompt_log = any((sg.sampling_params.prompt_logprobs
                                  and sg.sampling_params.prompt_logprobs > 0)
                                 for sg in target_seq_group_metadata_list)
            # TODO (NickLucche) we should surface `disable_logprobs` as to not
            # break abstraction to get its value.
            if (not self._scorer_worker.model_runner.disable_logprobs\
                and has_prompt_log):
                prompt_logprobs = [
                    o.prompt_logprobs for o in target_sampler_output.outputs
                ]

            # Split loop into prefill|decode for readability.
            start_loc, i = 0, 0
            while i < len(target_seq_group_metadata_list
                          ) and target_seq_group_metadata_list[i].is_prompt:
                seq_meta = target_seq_group_metadata_list[i]
                end_loc = start_loc
                if has_prompt_log:
                    end_loc += seq_meta.token_chunk_size
                elif seq_meta.do_sample:
                    end_loc += 1

                # Skip chunks with no output tokens.
                if seq_meta.do_sample:
                    # Get sampled token (last position in chunk) and its prob.
                    all_tokens[i, 0] = target_token_ids[end_loc - 1]
                    all_probs[i, 0] = target_probs[end_loc - 1]
                    all_logprobs[i, 0] = target_logprobs[end_loc - 1]

                i += 1
                start_loc = end_loc
            # Decodes.
            while i < len(target_seq_group_metadata_list):
                proposed_len, seq_meta = all_proposal_lengths[
                    i], target_seq_group_metadata_list[i]
                output_len = proposed_len + 1
                end_loc = start_loc + output_len
                all_tokens[
                    i, :output_len] = target_token_ids[start_loc:end_loc]
                all_probs[i, :output_len] = target_probs[start_loc:end_loc]
                all_logprobs[
                    i, :output_len] = target_logprobs[start_loc:end_loc]
                start_loc = end_loc
                i += 1

        hidden_states = None
        if target_sampler_output.hidden_states is not None:
            hidden_states = target_sampler_output.hidden_states.reshape(
                bs, (k + 1), -1)

        return SpeculativeScores(probs=all_probs,
                                 token_ids=all_tokens,
                                 logprobs=all_logprobs,
                                 hidden_states=hidden_states,
                                 prompt_logprobs=prompt_logprobs)

score_proposals

score_proposals(
    execute_model_req: ExecuteModelRequest,
    proposals: SpeculativeProposals,
) -> SpeculativeScores
Source code in vllm/spec_decode/mqa_scorer.py
def score_proposals(
    self,
    execute_model_req: ExecuteModelRequest,
    proposals: SpeculativeProposals,
) -> SpeculativeScores:
    target_seq_group_metadata_list = []
    target_seq_id_start = max(
        get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1
    all_proposal_tokens = proposals.proposal_token_ids.tolist()
    all_proposal_lengths = proposals.proposal_lens.tolist()
    for i, seq_group_metadata in enumerate(
            execute_model_req.seq_group_metadata_list):
        if all_proposal_lengths[i] == 0:
            # Keep prompt seqs untouched (keep computed_tokens for chunks).
            target_seq_group_metadata_list.append(seq_group_metadata)
            continue

        seq_data_dict = seq_group_metadata.seq_data
        assert len(seq_data_dict) == 1
        seq_id = next(iter(seq_data_dict.keys()))

        seq_data: SequenceData = seq_data_dict[seq_id]
        prompt_token_ids = seq_data.get_prompt_token_ids()
        output_token_ids = seq_data.get_output_token_ids()
        proposal_token_ids = all_proposal_tokens[
            i][:all_proposal_lengths[i]]
        new_output_token_ids = [*output_token_ids, *proposal_token_ids]

        target_seq_id = target_seq_id_start + i
        new_seq_data = SequenceData.from_seqs(
            prompt_token_ids=prompt_token_ids,
            output_token_ids=new_output_token_ids,
        )
        new_seq_data.update_num_computed_tokens(
            len(prompt_token_ids) + len(output_token_ids) - 1)

        # Ensure that the new decode sequence has at least one token.
        assert len(output_token_ids) >= 1
        new_seq_data_dict = {target_seq_id: new_seq_data}

        new_seq_group_metadata = SequenceGroupMetadata(
            request_id=seq_group_metadata.request_id,
            is_prompt=seq_group_metadata.is_prompt,
            seq_data=new_seq_data_dict,
            sampling_params=seq_group_metadata.sampling_params,
            block_tables={
                target_seq_id: seq_group_metadata.block_tables[seq_id],
            },
            lora_request=None,
        )
        target_seq_group_metadata_list.append(new_seq_group_metadata)

    target_sampler_output = self._scorer_worker.execute_model(
        execute_model_req=execute_model_req.clone(
            seq_group_metadata_list=target_seq_group_metadata_list))

    target_sampler_output = target_sampler_output[0]

    k = execute_model_req.num_lookahead_slots
    bs = len(execute_model_req.seq_group_metadata_list)
    target_token_ids = target_sampler_output.sampled_token_ids
    target_probs = target_sampler_output.sampled_token_probs
    target_logprobs = target_sampler_output.logprobs
    prompt_logprobs = None

    # If all requests have the same number of query tokens, we can avoid
    # the for loop to build output for better performance.
    if min(all_proposal_lengths) == k:
        # Regular decodes only.
        assert all(not sg.is_prompt
                   for sg in target_seq_group_metadata_list
                   if sg.is_prompt)
        bs, _ = proposals.proposal_token_ids.shape
        all_tokens = target_token_ids.reshape(bs, k + 1)
        all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
        all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size)
    else:
        # We either have decodes with different lens or prefill+decodes.
        all_tokens = target_token_ids.new_full(size=(bs, k + 1),
                                               fill_value=-1)
        all_probs = target_probs.new_zeros(*all_tokens.shape,
                                           self._vocab_size)
        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                fill_value=-float("inf"))
        target_token_ids = target_token_ids.flatten()

        # When prompt logprobs is enabled, lens of returned tensors go from
        # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
        # We adjust stride accordingly to get the generated tokens and
        # their probs, but pass on prompt_logprobs as is, since it may be
        # that n_prompts >> K.
        has_prompt_log = any((sg.sampling_params.prompt_logprobs
                              and sg.sampling_params.prompt_logprobs > 0)
                             for sg in target_seq_group_metadata_list)
        # TODO (NickLucche) we should surface `disable_logprobs` as to not
        # break abstraction to get its value.
        if (not self._scorer_worker.model_runner.disable_logprobs\
            and has_prompt_log):
            prompt_logprobs = [
                o.prompt_logprobs for o in target_sampler_output.outputs
            ]

        # Split loop into prefill|decode for readability.
        start_loc, i = 0, 0
        while i < len(target_seq_group_metadata_list
                      ) and target_seq_group_metadata_list[i].is_prompt:
            seq_meta = target_seq_group_metadata_list[i]
            end_loc = start_loc
            if has_prompt_log:
                end_loc += seq_meta.token_chunk_size
            elif seq_meta.do_sample:
                end_loc += 1

            # Skip chunks with no output tokens.
            if seq_meta.do_sample:
                # Get sampled token (last position in chunk) and its prob.
                all_tokens[i, 0] = target_token_ids[end_loc - 1]
                all_probs[i, 0] = target_probs[end_loc - 1]
                all_logprobs[i, 0] = target_logprobs[end_loc - 1]

            i += 1
            start_loc = end_loc
        # Decodes.
        while i < len(target_seq_group_metadata_list):
            proposed_len, seq_meta = all_proposal_lengths[
                i], target_seq_group_metadata_list[i]
            output_len = proposed_len + 1
            end_loc = start_loc + output_len
            all_tokens[
                i, :output_len] = target_token_ids[start_loc:end_loc]
            all_probs[i, :output_len] = target_probs[start_loc:end_loc]
            all_logprobs[
                i, :output_len] = target_logprobs[start_loc:end_loc]
            start_loc = end_loc
            i += 1

    hidden_states = None
    if target_sampler_output.hidden_states is not None:
        hidden_states = target_sampler_output.hidden_states.reshape(
            bs, (k + 1), -1)

    return SpeculativeScores(probs=all_probs,
                             token_ids=all_tokens,
                             logprobs=all_logprobs,
                             hidden_states=hidden_states,
                             prompt_logprobs=prompt_logprobs)