vllm.v1.engine.logprobs

NONES module-attribute

NONES = repeat(None)
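
Since zip() stops at its shortest input, this infinite iterator of None values supplies exactly one placeholder per real token whenever detokenization is disabled. A minimal sketch with made-up ids and logprobs:

from itertools import repeat

NONES = repeat(None)

token_ids = [17, 42, 99]
logprobs = [-0.1, -1.3, -2.7]

# zip() is bounded by the finite lists, so NONES yields exactly
# one None "decoded token" per position.
for tid, lp, tok in zip(token_ids, logprobs, NONES):
    print(tid, lp, tok)  # 17 -0.1 None, then 42 -1.3 None, ...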

logger module-attribute

logger = init_logger(__name__)

LogprobsProcessor dataclass

Source code in vllm/v1/engine/logprobs.py
@dataclass
class LogprobsProcessor:

    # Tokenizer for this request,
    # None if detokenization is disabled.
    tokenizer: Optional[AnyTokenizer]

    # Logprobs for this request
    logprobs: Optional[SampleLogprobs]
    prompt_logprobs: Optional[PromptLogprobs]
    cumulative_logprob: Optional[float]
    num_logprobs: Optional[int]
    num_prompt_logprobs: Optional[int]

    @classmethod
    def from_new_request(
        cls,
        tokenizer: Optional[AnyTokenizer],
        request: EngineCoreRequest,
    ) -> "LogprobsProcessor":
        assert request.sampling_params is not None
        num_logprobs = request.sampling_params.logprobs
        num_prompt_logprobs = request.sampling_params.prompt_logprobs
        return cls(
            tokenizer=tokenizer,
            cumulative_logprob=(None if num_logprobs is None else 0.),
            logprobs=(None if num_logprobs is None else []),
            # NOTE: logprob of first prompt token is None.
            prompt_logprobs=(None if num_prompt_logprobs is None else [None]),
            num_prompt_logprobs=num_prompt_logprobs,
            num_logprobs=num_logprobs,
        )

    def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
        """Update with sample logprobs from EngineCore.

        Outer lists have len > 1 only if EngineCore generated
        more than one token in the prior step (e.g. in spec decoding).

        Args:
          logprobs_lists: the lists of logprob tokens, logprobs, and ranks.

        """

        assert self.num_logprobs is not None
        assert self.logprobs is not None
        assert self.cumulative_logprob is not None

        token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists

        for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst,
                                             token_ids_lst):

            # Detokenize (non-incrementally).
            decoded_tokens = NONES if self.tokenizer is None else (
                convert_ids_list_to_tokens(self.tokenizer, token_ids))

            # The sampler puts the sampled token's logprob first.
            sampled_token_logprob = logprobs[0]
            self.cumulative_logprob += sampled_token_logprob

            # Update with the Logprob dictionary for this pos.
            self.logprobs.append(
                self._make_logprob_dict(
                    logprobs,
                    token_ids,
                    decoded_tokens,
                    rank,
                    self.num_logprobs,
                ))

    def _update_prompt_logprobs(
        self,
        prompt_logprobs_tensors: LogprobsTensors,
    ) -> None:
        """Update with prompt logprobs from EngineCore.

        Args:
          prompt_logprobs_tensors: tuple containing the prompt logprobs
                                   tensors.

        """

        # Prompt logprobs are enabled.
        assert self.num_prompt_logprobs is not None
        assert self.prompt_logprobs is not None

        token_ids, logprobs, ranks = prompt_logprobs_tensors

        # Detokenize non-incrementally.
        # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
        decoded_tokens = None if self.tokenizer is None else (
            convert_ids_list_to_tokens(self.tokenizer,
                                       token_ids.flatten().tolist()))

        # Recover shapes.
        num_prompt_tokens, num_logprobs = logprobs.shape

        # Pythonize the torch tensors.
        prompt_token_ranks = ranks.tolist()
        prompt_logprobs = logprobs.tolist()
        token_ids = token_ids.tolist()

        # Make Logprob for each position.
        for pos in range(num_prompt_tokens):
            # Handle flattening.
            offset = pos * num_logprobs
            offset_end = offset + num_logprobs
            decoded_tokens_for_pos = (NONES if decoded_tokens is None
                                      else decoded_tokens[offset:offset_end])

            # Update with the Logprob dictionary for this pos.
            self.prompt_logprobs.append(
                self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos],
                                        decoded_tokens_for_pos,
                                        prompt_token_ranks[pos],
                                        self.num_prompt_logprobs))

    def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
        """Pop and return all request prompt logprobs

        The logprobs processor aggregates prompt chunk logprobs
        over one or more prefill chunks. This method returns
        all prompt logprobs at once and then forgets them.
        Ensures correct RequestOutputKind.DELTA semantics
        wherein all prompt logprobs are returned at once at
        the end of prefill.

        Returns:
          None if prompt logprobs are disabled for this request.
          List of all prompt logprobs, otherwise.
        """
        plp = self.prompt_logprobs
        if plp:
            self.prompt_logprobs = []
        return plp

    @staticmethod
    def _make_logprob_dict(
        logprobs: list[float],
        logprob_token_ids: list[int],
        decoded_tokens: Iterable[Optional[str]],
        rank: int,
        num_logprobs: int,
    ) -> dict[int, Logprob]:
        """Make a Logprob dictionary for a position.

        Args:
          logprobs: list of log probabilities
          logprob_token_ids: list of top token ids
          decoded_tokens: list of decoded top tokens
          rank: rank of the sampled token
          num_logprobs: number of logprobs requested
            by the user (in addition to sampled logprob)

        Returns:
          dict[token id, Logprob]
        """

        # We do not need a special case for the sampled token
        # being in the topk, since inserting duplicated data
        # into a dictionary twice is the same as doing it once.
        topk_ranks = range(1, num_logprobs + 1)
        ranks = itertools.chain((rank, ), topk_ranks)

        return {
            token_id: Logprob(
                logprob=logprob,
                rank=rank,
                decoded_token=token,
            )
            for token_id, logprob, rank, token in zip(
                logprob_token_ids, logprobs, ranks, decoded_tokens)
        }

    def update_from_output(self, output: EngineCoreOutput) -> None:
        """Apply any new sample and prompt logprobs from an engine step."""
        if output.new_logprobs is not None:
            self._update_sample_logprobs(output.new_logprobs)
        if output.new_prompt_logprobs_tensors is not None:
            self._update_prompt_logprobs(output.new_prompt_logprobs_tensors)

cumulative_logprob instance-attribute

cumulative_logprob: Optional[float]

logprobs instance-attribute

logprobs: Optional[SampleLogprobs]

num_logprobs instance-attribute

num_logprobs: Optional[int]

num_prompt_logprobs instance-attribute

num_prompt_logprobs: Optional[int]

prompt_logprobs instance-attribute

prompt_logprobs: Optional[PromptLogprobs]

tokenizer instance-attribute

tokenizer: Optional[AnyTokenizer]

__init__

__init__(
    tokenizer: Optional[AnyTokenizer],
    logprobs: Optional[SampleLogprobs],
    prompt_logprobs: Optional[PromptLogprobs],
    cumulative_logprob: Optional[float],
    num_logprobs: Optional[int],
    num_prompt_logprobs: Optional[int],
) -> None

_make_logprob_dict staticmethod

_make_logprob_dict(
    logprobs: list[float],
    logprob_token_ids: list[int],
    decoded_tokens: Iterable[Optional[str]],
    rank: int,
    num_logprobs: int,
) -> dict[int, Logprob]

Make a Logprob dictionary for a position.

Parameters:

  Name               Type                     Description                            Default
  logprobs           list[float]              list of log probabilities              required
  logprob_token_ids  list[int]                list of top token ids                  required
  decoded_tokens     Iterable[Optional[str]]  list of decoded top tokens             required
  rank               int                      rank of the sampled token              required
  num_logprobs       int                      number of logprobs requested by the    required
                                              user (in addition to sampled logprob)

Returns:

  Type                Description
  dict[int, Logprob]  dict[token id, Logprob]

Source code in vllm/v1/engine/logprobs.py
@staticmethod
def _make_logprob_dict(
    logprobs: list[float],
    logprob_token_ids: list[int],
    decoded_tokens: Iterable[Optional[str]],
    rank: int,
    num_logprobs: int,
) -> dict[int, Logprob]:
    """Make a Logprob dictionary for a position.

    Args:
      logprobs: list of log probabilities
      logprob_token_ids: list of top token ids
      decoded_tokens: list of decoded top tokens
      rank: rank of the sampled token
      num_logprobs: number of logprobs requested
        by the user (in addition to sampled logprob)

    Returns:
      dict[token id, Logprob]
    """

    # We do not need a special case for the sampled token
    # being in the topk, since inserting duplicated data
    # into a dictionary twice is the same as doing it once.
    topk_ranks = range(1, num_logprobs + 1)
    ranks = itertools.chain((rank, ), topk_ranks)

    return {
        token_id: Logprob(
            logprob=logprob,
            rank=rank,
            decoded_token=token,
        )
        for token_id, logprob, rank, token in zip(
            logprob_token_ids, logprobs, ranks, decoded_tokens)
    }
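
A worked sketch of the comprehension above with fabricated values; Logprob here is a stand-in dataclass for vllm's Logprob container. It shows the duplicate-key behavior the comment describes: when the sampled token also appears in the top-k, both writes carry identical data, so the overwrite is a no-op.

import itertools
from dataclasses import dataclass
from typing import Optional

@dataclass
class Logprob:  # stand-in for vllm's Logprob container
    logprob: float
    rank: int
    decoded_token: Optional[str]

# Sampled token 42 has rank 1 and is also the top-1 candidate,
# so it appears twice in the lists (position 0 and position 1).
logprob_token_ids = [42, 42, 7]
logprobs = [-0.1, -0.1, -2.3]
decoded_tokens = ["hi", "hi", "lo"]
rank, num_logprobs = 1, 2

topk_ranks = range(1, num_logprobs + 1)
ranks = itertools.chain((rank,), topk_ranks)

d = {
    token_id: Logprob(logprob=lp, rank=r, decoded_token=tok)
    for token_id, lp, r, tok in zip(logprob_token_ids, logprobs, ranks,
                                    decoded_tokens)
}
# len(d) == 2: keys 42 and 7. The second write to key 42 carried
# the same (logprob, rank, token) data, so nothing changed.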

_update_prompt_logprobs

_update_prompt_logprobs(
    prompt_logprobs_tensors: LogprobsTensors,
) -> None

Update with prompt logprobs from EngineCore.

Parameters:

  Name                     Type             Description                                   Default
  prompt_logprobs_tensors  LogprobsTensors  tuple containing the prompt logprobs tensors  required

Source code in vllm/v1/engine/logprobs.py
def _update_prompt_logprobs(
    self,
    prompt_logprobs_tensors: LogprobsTensors,
) -> None:
    """Update with prompt logprobs from EngineCore.

    Args:
      prompt_logprobs_tensors: tuple containing the prompt logprobs
                               tensors.

    """

    # Prompt logprobs are enabled.
    assert self.num_prompt_logprobs is not None
    assert self.prompt_logprobs is not None

    token_ids, logprobs, ranks = prompt_logprobs_tensors

    # Detokenize non-incrementally.
    # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
    decoded_tokens = None if self.tokenizer is None else (
        convert_ids_list_to_tokens(self.tokenizer,
                                   token_ids.flatten().tolist()))

    # Recover shapes.
    num_prompt_tokens, num_logprobs = logprobs.shape

    # Pythonize the torch tensors.
    prompt_token_ranks = ranks.tolist()
    prompt_logprobs = logprobs.tolist()
    token_ids = token_ids.tolist()

    # Make Logprob for each position.
    for pos in range(num_prompt_tokens):
        # Handle flattening.
        offset = pos * num_logprobs
        offset_end = offset + num_logprobs
        decoded_tokens_for_pos = (NONES if decoded_tokens is None
                                  else decoded_tokens[offset:offset_end])

        # Update with the Logprob dictionary for this pos.
        self.prompt_logprobs.append(
            self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos],
                                    decoded_tokens_for_pos,
                                    prompt_token_ranks[pos],
                                    self.num_prompt_logprobs))
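
A tiny sketch of the offset arithmetic above, with made-up strings: a [num_tok, num_lps] matrix flattened row-major recovers row pos as the slice [pos * num_lps : (pos + 1) * num_lps].

num_prompt_tokens, num_logprobs = 3, 2
# Row-major flattening of a 3 x 2 matrix of decoded tokens.
decoded_tokens = ["a1", "a2", "b1", "b2", "c1", "c2"]

for pos in range(num_prompt_tokens):
    offset = pos * num_logprobs
    row = decoded_tokens[offset:offset + num_logprobs]
    print(pos, row)
# 0 ['a1', 'a2']
# 1 ['b1', 'b2']
# 2 ['c1', 'c2']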

_update_sample_logprobs

_update_sample_logprobs(
    logprobs_lists: LogprobsLists,
) -> None

Update with sample logprobs from EngineCore.

Outer lists have len > 1 only if EngineCore generated more than one token in the prior step (e.g. in spec decoding).

Parameters:

  Name            Type           Description                                       Default
  logprobs_lists  LogprobsLists  the lists of logprob tokens, logprobs, and ranks  required

Source code in vllm/v1/engine/logprobs.py
def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
    """Update with sample logprobs from EngineCore.

    Outer lists have len > 1 only if EngineCore generated
    more than one token in the prior step (e.g. in spec decoding).

    Args:
      logprobs_lists: the lists of logprob tokens, logprobs, and ranks.

    """

    assert self.num_logprobs is not None
    assert self.logprobs is not None
    assert self.cumulative_logprob is not None

    token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists

    for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst,
                                         token_ids_lst):

        # Detokenize (non-incrementally).
        decoded_tokens = NONES if self.tokenizer is None else (
            convert_ids_list_to_tokens(self.tokenizer, token_ids))

        # The sampler puts the sampled token's logprob first.
        sampled_token_logprob = logprobs[0]
        self.cumulative_logprob += sampled_token_logprob

        # Update with the Logprob dictionary for this pos.
        self.logprobs.append(
            self._make_logprob_dict(
                logprobs,
                token_ids,
                decoded_tokens,
                rank,
                self.num_logprobs,
            ))
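
A sketch of the accumulation above with fabricated values, assuming EngineCore produced two tokens in the prior step (as in spec decoding). Each inner list leads with the sampled token's id/logprob, followed by the top-k entries.

# One entry per generated position; index 0 is the sampled token.
token_ids_lst = [[5, 5, 9], [11, 3, 11]]
logprobs_lst = [[-0.2, -0.2, -1.9], [-0.5, -0.3, -0.5]]
ranks_lst = [1, 2]  # sampled token's rank at each position

cumulative_logprob = 0.0
for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst):
    cumulative_logprob += logprobs[0]  # sampled logprob comes first

print(cumulative_logprob)  # -0.7 (= -0.2 + -0.5)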

from_new_request classmethod

from_new_request(
    tokenizer: Optional[AnyTokenizer],
    request: EngineCoreRequest,
) -> LogprobsProcessor

Source code in vllm/v1/engine/logprobs.py
@classmethod
def from_new_request(
    cls,
    tokenizer: Optional[AnyTokenizer],
    request: EngineCoreRequest,
) -> "LogprobsProcessor":
    assert request.sampling_params is not None
    num_logprobs = request.sampling_params.logprobs
    num_prompt_logprobs = request.sampling_params.prompt_logprobs
    return cls(
        tokenizer=tokenizer,
        cumulative_logprob=(None if num_logprobs is None else 0.),
        logprobs=(None if num_logprobs is None else []),
        # NOTE: logprob of first prompt token is None.
        prompt_logprobs=(None if num_prompt_logprobs is None else [None]),
        num_prompt_logprobs=num_prompt_logprobs,
        num_logprobs=num_logprobs,
    )
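
For illustration, a minimal sketch of the state this constructor produces when sampling_params.logprobs == 5 and sampling_params.prompt_logprobs is None; the direct construction below is equivalent.

from vllm.v1.engine.logprobs import LogprobsProcessor

proc = LogprobsProcessor(
    tokenizer=None,          # detokenization disabled for this sketch
    logprobs=[],             # sample logprobs requested: start empty
    prompt_logprobs=None,    # prompt logprobs not requested
    cumulative_logprob=0.0,
    num_logprobs=5,
    num_prompt_logprobs=None,
)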

pop_prompt_logprobs

pop_prompt_logprobs() -> Optional[PromptLogprobs]

Pop and return all request prompt logprobs

The logprobs processor aggregates prompt chunk logprobs over one or more prefill chunks. This method returns all prompt logprobs at once and then forgets them. Ensures correct RequestOutputKind.DELTA semantics wherein all prompt logprobs are returned at once at the end of prefill.

Returns:

  Type                      Description
  Optional[PromptLogprobs]  None if prompt logprobs are disabled for this
                            request; the list of all prompt logprobs otherwise.

Source code in vllm/v1/engine/logprobs.py
def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
    """Pop and return all request prompt logprobs

    The logprobs processor aggregates prompt chunk logprobs
    over one or more prefill chunks. This method returns
    all prompt logprobs at once and then forgets them.
    Ensures correct RequestOutputKind.DELTA semantics
    wherein all prompt logprobs are returned at once at
    the end of prefill.

    Returns:
      None if prompt logprobs are disabled for this request.
      List of all prompt logprobs, otherwise.
    """
    plp = self.prompt_logprobs
    if plp:
        self.prompt_logprobs = []
    return plp
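
A sketch of the pop semantics, assuming proc is a LogprobsProcessor whose request had prompt logprobs enabled and whose prefill just finished:

# First element is None: the first prompt token has no logprob.
accumulated = proc.prompt_logprobs        # e.g. [None, {...}, {...}]

first_pop = proc.pop_prompt_logprobs()    # the full accumulated list
second_pop = proc.pop_prompt_logprobs()   # [] -- already consumed
assert first_pop is accumulated and second_pop == []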

update_from_output

update_from_output(output: EngineCoreOutput) -> None

Source code in vllm/v1/engine/logprobs.py
def update_from_output(self, output: EngineCoreOutput) -> None:
    """Apply any new sample and prompt logprobs from an engine step."""
    if output.new_logprobs is not None:
        self._update_sample_logprobs(output.new_logprobs)
    if output.new_prompt_logprobs_tensors is not None:
        self._update_prompt_logprobs(output.new_prompt_logprobs_tensors)
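
Putting it together, a hypothetical per-request driver loop; engine_core_outputs and emit_prompt_logprobs are stand-ins for illustration, not vllm APIs:

for output in engine_core_outputs:
    processor.update_from_output(output)

    # DELTA semantics: prompt logprobs surface once, after prefill.
    prompt_lps = processor.pop_prompt_logprobs()
    if prompt_lps:
        emit_prompt_logprobs(prompt_lps)  # hypothetical sink

sample_lps = processor.logprobs       # list of per-position Logprob dicts
total = processor.cumulative_logprob  # running sum of sampled logprobs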