Skip to content

vllm.v1.spec_decode.metadata

SpecDecodeMetadata dataclass

Source code in vllm/v1/spec_decode/metadata.py
@dataclass
class SpecDecodeMetadata:
    """Draft-token bookkeeping for one speculative-decoding batch.

    Shape comments use ``num_tokens`` for the total number of draft tokens
    across all requests and ``batch_size`` for the number of requests.
    ``max_spec_len`` (set in ``__post_init__``) is the largest per-request
    draft count, or 0 for an empty batch.
    """

    # [num_tokens]
    draft_token_ids: torch.Tensor
    # [batch_size]
    num_draft_tokens: list[int]
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor
    # [num_tokens]
    target_logits_indices: torch.Tensor
    # [batch_size]
    bonus_logits_indices: torch.Tensor
    # [num_tokens + batch_size]
    logits_indices: torch.Tensor

    def __post_init__(self):
        # Longest draft length in the batch. `default=0` keeps an empty
        # batch from raising ValueError inside `max()`.
        self.max_spec_len = max(self.num_draft_tokens, default=0)

    @classmethod
    def make_dummy(
        cls,
        draft_token_ids: list[list[int]],
        device: torch.device,
    ) -> "SpecDecodeMetadata":
        """Build metadata for the given per-request draft tokens.

        Only ``draft_token_ids``, ``num_draft_tokens`` and
        ``cu_num_draft_tokens`` are derived from the input; the three
        logits-index tensors are zero-filled with the correct lengths.

        Args:
            draft_token_ids: One list of draft token ids per request.
            device: Device the returned tensors are placed on.

        Returns:
            An instance of ``cls`` describing the batch.
        """
        batch_size = len(draft_token_ids)
        num_draft_tokens = [len(ids) for ids in draft_token_ids]
        # Flatten in linear time; `sum(lists, [])` re-copies the growing
        # accumulator on every addition and is quadratic in the batch size.
        flattened_draft_token_ids = [
            token_id for ids in draft_token_ids for token_id in ids
        ]
        num_tokens = len(flattened_draft_token_ids)

        draft_token_ids_tensor = torch.tensor(flattened_draft_token_ids,
                                              dtype=torch.int32,
                                              device=device)
        # Cumulative draft counts are computed on the host, then moved.
        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
        cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(
            device)

        target_logits_indices = torch.zeros(num_tokens,
                                            dtype=torch.int32,
                                            device=device)
        bonus_logits_indices = torch.zeros(batch_size,
                                           dtype=torch.int32,
                                           device=device)
        logits_indices = torch.zeros(num_tokens + batch_size,
                                     dtype=torch.int32,
                                     device=device)
        return cls(
            draft_token_ids=draft_token_ids_tensor,
            num_draft_tokens=num_draft_tokens,
            cu_num_draft_tokens=cu_num_draft_tokens_tensor,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )

bonus_logits_indices instance-attribute

bonus_logits_indices: Tensor

cu_num_draft_tokens instance-attribute

cu_num_draft_tokens: Tensor

draft_token_ids instance-attribute

draft_token_ids: Tensor

logits_indices instance-attribute

logits_indices: Tensor

num_draft_tokens instance-attribute

num_draft_tokens: list[int]

target_logits_indices instance-attribute

target_logits_indices: Tensor

__init__

__init__(
    draft_token_ids: Tensor,
    num_draft_tokens: list[int],
    cu_num_draft_tokens: Tensor,
    target_logits_indices: Tensor,
    bonus_logits_indices: Tensor,
    logits_indices: Tensor,
) -> None

__post_init__

__post_init__()
Source code in vllm/v1/spec_decode/metadata.py
def __post_init__(self):
    """Set ``max_spec_len`` to the largest per-request draft count.

    ``default=0`` keeps an empty batch from raising ``ValueError`` in
    ``max()``.
    """
    self.max_spec_len = max(self.num_draft_tokens, default=0)

make_dummy classmethod

make_dummy(
    draft_token_ids: list[list[int]], device: device
) -> SpecDecodeMetadata
Source code in vllm/v1/spec_decode/metadata.py
@classmethod
def make_dummy(
    cls,
    draft_token_ids: list[list[int]],
    device: torch.device,
) -> "SpecDecodeMetadata":
    """Build metadata for the given per-request draft tokens.

    Only ``draft_token_ids``, ``num_draft_tokens`` and ``cu_num_draft_tokens``
    are derived from the input; the three logits-index tensors are
    zero-filled with the correct lengths.

    Args:
        draft_token_ids: One list of draft token ids per request.
        device: Device the returned tensors are placed on.

    Returns:
        An instance of ``cls`` describing the batch.
    """
    batch_size = len(draft_token_ids)
    num_draft_tokens = [len(ids) for ids in draft_token_ids]
    # Flatten in linear time; `sum(lists, [])` re-copies the growing
    # accumulator on every addition and is quadratic in the batch size.
    flattened_draft_token_ids = [
        token_id for ids in draft_token_ids for token_id in ids
    ]
    num_tokens = len(flattened_draft_token_ids)

    draft_token_ids_tensor = torch.tensor(flattened_draft_token_ids,
                                          dtype=torch.int32,
                                          device=device)
    # Cumulative draft counts are computed on the host, then moved.
    cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
    cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(
        device)

    target_logits_indices = torch.zeros(num_tokens,
                                        dtype=torch.int32,
                                        device=device)
    bonus_logits_indices = torch.zeros(batch_size,
                                       dtype=torch.int32,
                                       device=device)
    logits_indices = torch.zeros(num_tokens + batch_size,
                                 dtype=torch.int32,
                                 device=device)
    return cls(
        draft_token_ids=draft_token_ids_tensor,
        num_draft_tokens=num_draft_tokens,
        cu_num_draft_tokens=cu_num_draft_tokens_tensor,
        target_logits_indices=target_logits_indices,
        bonus_logits_indices=bonus_logits_indices,
        logits_indices=logits_indices,
    )