Skip to content

vllm.v1.spec_decode.ngram_proposer_gpu

GPU-accelerated N-gram proposer using fully async PyTorch tensor operations.

This version uses a fully vectorized approach with unfold and argmax for finding the first match across all sequences in parallel.

NgramGPUKernel

Bases: Module

GPU-accelerated N-gram proposer using fully async tensor operations.

Source code in vllm/v1/spec_decode/ngram_proposer_gpu.py
@support_torch_compile()
class NgramGPUKernel(nn.Module):
    """GPU-accelerated N-gram proposer using fully async tensor operations.

    For every sequence in the batch, the trailing n-gram (for each size in
    ``[min_n, max_n]``) is compared against all windows of the same size in
    the sequence. The tokens that follow the earliest occurrence of the
    longest matching n-gram are returned as draft tokens.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        device: torch.device | str = "cuda",
    ):
        """
        Read the n-gram search parameters from the speculative config.

        Args:
            vllm_config: Configuration; ``speculative_config`` must be set
                and carry ``prompt_lookup_min`` / ``prompt_lookup_max``.
            prefix: Not used by this constructor (presumably consumed by the
                ``support_torch_compile`` wrapper; confirm).
            device: Device recorded on the module as ``self.device``.
        """
        super().__init__()

        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.prompt_lookup_min is not None
        assert vllm_config.speculative_config.prompt_lookup_max is not None

        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        self.k = vllm_config.speculative_config.num_speculative_tokens
        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.device = device

    def _find_first_and_extract_all_n_parallel(
        self,
        token_ids: torch.Tensor,
        seq_lengths: torch.Tensor,
        min_ngram_len: int,
        max_ngram_len: int,
        num_draft_tokens: int,
    ) -> torch.Tensor:
        """
        Find suffix n-gram matches and extract following tokens.
        Searches for the earliest prior occurrence of the trailing n-gram,
        tries multiple lengths, and picks the longest valid match.

        Args:
            token_ids: Token IDs for each sequence
            seq_lengths: Actual length of each sequence (excluding padding)
            min_ngram_len: Minimum n-gram size to search for (e.g., 2)
            max_ngram_len: Maximum n-gram size to search for (e.g., 5)
            num_draft_tokens: Number of tokens to extract after match (k)

        Returns:
            Draft token predictions; -1 means invalid/no match.
        """
        batch_size = token_ids.shape[0]
        max_seq_len = token_ids.shape[1]
        device = token_ids.device
        num_ngram_sizes = max_ngram_len - min_ngram_len + 1

        # All n-gram sizes to try.
        ngram_lengths = torch.arange(min_ngram_len, max_ngram_len + 1, device=device)
        batch_indices = torch.arange(batch_size, device=device)

        # Earliest match per (sequence, ngram_len); -1 means no match.
        first_match_positions = torch.full(
            (batch_size, num_ngram_sizes), -1, dtype=torch.long, device=device
        )

        for i, ngram_len in enumerate(range(min_ngram_len, max_ngram_len + 1)):
            # Sliding windows of size ngram_len; unfold is O(1) view.
            search_windows = token_ids.unfold(1, ngram_len, 1)
            num_windows = search_windows.shape[1]

            # Trailing suffix (last ngram_len tokens) for each sequence.
            # Sequences shorter than ngram_len yield negative indices, which
            # are clamped to 0; such rows are rejected by valid_mask below.
            suffix_starts = seq_lengths - ngram_len
            suffix_indices = suffix_starts.unsqueeze(1) + torch.arange(
                ngram_len, device=device
            )
            suffix = torch.gather(token_ids, 1, suffix_indices.clamp(min=0))

            # Window matches for each sequence.
            matches = (search_windows == suffix.unsqueeze(1)).all(dim=-1)

            # Match must leave room for at least one draft token. This also
            # excludes the suffix itself and any window in the padding.
            max_valid_suffix_start = seq_lengths - ngram_len - 1
            window_positions = torch.arange(num_windows, device=device)
            valid_mask = window_positions <= max_valid_suffix_start.unsqueeze(1)
            final_matches = matches & valid_mask

            # Find earliest match (argmax=0 when empty; verify with has_match).
            first_match_idx = torch.argmax(final_matches.int(), dim=1)
            has_match = final_matches[batch_indices, first_match_idx]

            # Store valid match positions (window index = position).
            first_match_positions[:, i] = torch.where(has_match, first_match_idx, -1)

        # Select the longest n-gram with a match: argmax over the reversed
        # columns returns the first hit counted from the largest size.
        best_ngram_idx = (first_match_positions >= 0).int().flip(dims=[1]).argmax(dim=1)
        best_ngram_idx = num_ngram_sizes - 1 - best_ngram_idx  # Flip back

        # Match position for the best n-gram.
        best_match_pos = first_match_positions[batch_indices, best_ngram_idx]

        # Avoid data-dependent branching.
        has_any_match = best_match_pos >= 0

        # Length of the best matching n-gram.
        best_ngram_lengths = ngram_lengths[best_ngram_idx]

        # Start position right after the matched n-gram (0 when no match, so
        # the gather below stays in bounds; the row is masked afterwards).
        draft_start = torch.where(
            has_any_match,
            best_match_pos + best_ngram_lengths,
            torch.zeros_like(best_match_pos),
        )
        tokens_available = seq_lengths - draft_start

        # Gather indices for draft tokens.
        draft_indices = draft_start.unsqueeze(1) + torch.arange(
            num_draft_tokens, device=device
        )
        draft_indices = draft_indices.clamp(min=0, max=max_seq_len - 1)

        # Extract draft tokens; gather always runs.
        draft_tokens = torch.gather(token_ids, 1, draft_indices)

        # Mask positions beyond available tokens.
        position_indices = torch.arange(num_draft_tokens, device=device).unsqueeze(0)
        valid_positions = position_indices < tokens_available.unsqueeze(1)

        draft_tokens = torch.where(
            valid_positions,
            draft_tokens,
            torch.full_like(draft_tokens, -1),
        )

        # If no match, mask all positions.
        draft_tokens = torch.where(
            has_any_match.unsqueeze(1),
            draft_tokens,
            torch.full_like(draft_tokens, -1),
        )

        return draft_tokens

    def forward(
        self,
        num_tokens_no_spec: torch.Tensor,
        token_ids_gpu: torch.Tensor,
        combined_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for N-gram proposal using GPU tensor operations.

        Args:
            num_tokens_no_spec: Number of tokens for each sequence [batch_size]
            token_ids_gpu: Token IDs [batch_size, max_len]
            combined_mask: Whether each sequence is valid for spec decode [batch_size]

        Returns:
            draft_tokens: [batch_size, k] on GPU, same dtype as
                `token_ids_gpu`; -1 marks an invalid position.
            num_valid_draft_tokens: [batch_size] int64 on GPU (`sum`
                promotes the int32 mask), count of leading valid (non -1)
                tokens per request.
        """
        device = token_ids_gpu.device

        # NOTE(patchy): the outputs are produced inside forward. Do NOT
        #               pre-allocate them as a buffer, it breaks torch.compile.
        results = self._find_first_and_extract_all_n_parallel(
            token_ids_gpu,
            num_tokens_no_spec,
            min_ngram_len=self.min_n,
            max_ngram_len=self.max_n,
            num_draft_tokens=self.k,
        )

        # Rows that are not eligible for speculation propose nothing.
        draft_tokens = torch.where(combined_mask.unsqueeze(1), results, -1)

        # Count leading contiguous valid (non -1) tokens per request: the
        # running count equals the 1-based position only while every earlier
        # token was valid.
        is_valid = draft_tokens != -1  # [batch, k]
        cum_valid = is_valid.int().cumsum(dim=1)  # [batch, k]
        positions = torch.arange(1, self.k + 1, device=device).unsqueeze(0)
        num_valid_draft_tokens = (cum_valid == positions).int().sum(dim=1)

        return draft_tokens, num_valid_draft_tokens

    def load_model(self, *args, **kwargs):
        """No model to load for N-gram proposer."""
        pass

_find_first_and_extract_all_n_parallel

_find_first_and_extract_all_n_parallel(
    token_ids: Tensor,
    seq_lengths: Tensor,
    min_ngram_len: int,
    max_ngram_len: int,
    num_draft_tokens: int,
) -> Tensor

Find suffix n-gram matches and extract following tokens. Searches for the earliest prior occurrence of the trailing n-gram, tries multiple lengths, and picks the longest valid match.

Parameters:

Name Type Description Default
token_ids Tensor

Token IDs for each sequence

required
seq_lengths Tensor

Actual length of each sequence (excluding padding)

required
min_ngram_len int

Minimum n-gram size to search for (e.g., 2)

required
max_ngram_len int

Maximum n-gram size to search for (e.g., 5)

required
num_draft_tokens int

Number of tokens to extract after match (k)

required

Returns:

Type Description
Tensor

Draft token predictions; -1 means invalid/no match.

Source code in vllm/v1/spec_decode/ngram_proposer_gpu.py
def _find_first_and_extract_all_n_parallel(
    self,
    token_ids: torch.Tensor,
    seq_lengths: torch.Tensor,
    min_ngram_len: int,
    max_ngram_len: int,
    num_draft_tokens: int,
) -> torch.Tensor:
    """
    Propose draft tokens by looking up each sequence's trailing n-gram.

    For every n-gram size in ``[min_ngram_len, max_ngram_len]`` the last
    ``n`` tokens of a sequence are compared against every window of the
    same size. The earliest window that is followed by at least one real
    token counts as a match. Among the sizes that matched, the longest one
    wins, and the tokens right after its match become the draft.

    Args:
        token_ids: Token IDs for each sequence
        seq_lengths: Actual length of each sequence (excluding padding)
        min_ngram_len: Minimum n-gram size to search for (e.g., 2)
        max_ngram_len: Maximum n-gram size to search for (e.g., 5)
        num_draft_tokens: Number of tokens to extract after match (k)

    Returns:
        Draft token predictions; -1 means invalid/no match.
    """
    num_seqs = token_ids.shape[0]
    padded_len = token_ids.shape[1]
    dev = token_ids.device
    num_sizes = max_ngram_len - min_ngram_len + 1

    # Lookup table: column index -> n-gram size, plus a row selector.
    size_lookup = torch.arange(min_ngram_len, max_ngram_len + 1, device=dev)
    rows = torch.arange(num_seqs, device=dev)

    # earliest[b, c] is the first matching window start for sequence b and
    # size size_lookup[c]; -1 when that size has no match.
    earliest = torch.full((num_seqs, num_sizes), -1, dtype=torch.long, device=dev)

    for col, n in enumerate(range(min_ngram_len, max_ngram_len + 1)):
        # Every length-n window of every row (a strided view, no copy).
        windows = token_ids.unfold(1, n, 1)
        window_starts = torch.arange(windows.shape[1], device=dev)

        # The last n real tokens of each row. Rows shorter than n produce
        # negative indices; clamp them so gather stays in bounds (those rows
        # are rejected by the successor check below).
        tail_start = seq_lengths - n
        tail_idx = (tail_start.unsqueeze(1) + torch.arange(n, device=dev)).clamp(
            min=0
        )
        tail = torch.gather(token_ids, 1, tail_idx)

        same_as_tail = (windows == tail.unsqueeze(1)).all(dim=-1)

        # A usable window must start strictly before the tail itself, which
        # guarantees at least one real token follows it.
        has_successor = window_starts < tail_start.unsqueeze(1)
        hits = same_as_tail & has_successor

        # argmax returns the first True; it is 0 when there is none, so the
        # result is only kept for rows that really have a hit.
        first_hit = torch.argmax(hits.int(), dim=1)
        found = hits.any(dim=1)
        earliest[:, col] = torch.where(found, first_hit, -1)

    # Prefer the largest size: scan the columns right-to-left and map the
    # first hit back to its original column index.
    size_found = earliest >= 0
    best_col = num_sizes - 1 - size_found.int().flip(dims=[1]).argmax(dim=1)
    best_pos = earliest[rows, best_col]

    # Kept as a mask to avoid data-dependent branching.
    matched = best_pos >= 0

    # Drafts start right after the matched n-gram; unmatched rows use 0 as
    # a harmless placeholder and are blanked out at the end.
    after_match = best_pos + size_lookup[best_col]
    draft_start = torch.where(matched, after_match, torch.zeros_like(best_pos))
    remaining = seq_lengths - draft_start

    steps = torch.arange(num_draft_tokens, device=dev)
    gather_idx = (draft_start.unsqueeze(1) + steps).clamp(min=0, max=padded_len - 1)

    # The gather always runs; invalid entries are masked afterwards.
    candidates = torch.gather(token_ids, 1, gather_idx)

    # Keep a slot only if the row matched and the slot lies inside the
    # real (unpadded) part of the sequence.
    keep = (steps.unsqueeze(0) < remaining.unsqueeze(1)) & matched.unsqueeze(1)
    return torch.where(keep, candidates, torch.full_like(candidates, -1))

forward

forward(
    num_tokens_no_spec: Tensor,
    token_ids_gpu: Tensor,
    combined_mask: Tensor,
) -> tuple[Tensor, Tensor]

Forward pass for N-gram proposal using GPU tensor operations.

Parameters:

Name Type Description Default
num_tokens_no_spec Tensor

Number of tokens for each sequence [batch_size]

required
token_ids_gpu Tensor

Token IDs [batch_size, max_len]

required
combined_mask Tensor

Whether each sequence is valid for spec decode [batch_size]

required

Returns:

Name Type Description
draft_tokens Tensor

[batch_size, k] on GPU

num_valid_draft_tokens Tensor

[batch_size] integer tensor on GPU (int64, because `sum` promotes the int32 mask), count of leading valid (non -1) tokens per request.

Source code in vllm/v1/spec_decode/ngram_proposer_gpu.py
def forward(
    self,
    num_tokens_no_spec: torch.Tensor,
    token_ids_gpu: torch.Tensor,
    combined_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Forward pass for N-gram proposal using GPU tensor operations.

    Args:
        num_tokens_no_spec: Number of tokens for each sequence [batch_size]
        token_ids_gpu: Token IDs [batch_size, max_len]
        combined_mask: Whether each sequence is valid for spec decode [batch_size]

    Returns:
        draft_tokens: [batch_size, k] on GPU, same dtype as
            `token_ids_gpu`; -1 marks an invalid position.
        num_valid_draft_tokens: [batch_size] int64 on GPU (`sum`
            promotes the int32 mask), count of leading valid (non -1)
            tokens per request.
    """
    device = token_ids_gpu.device

    # NOTE(patchy): the outputs are produced inside forward. Do NOT
    #               pre-allocate them as a buffer, it breaks torch.compile.
    results = self._find_first_and_extract_all_n_parallel(
        token_ids_gpu,
        num_tokens_no_spec,
        min_ngram_len=self.min_n,
        max_ngram_len=self.max_n,
        num_draft_tokens=self.k,
    )

    # Rows that are not eligible for speculation propose nothing.
    draft_tokens = torch.where(combined_mask.unsqueeze(1), results, -1)

    # Count leading contiguous valid (non -1) tokens per request: the
    # running count equals the 1-based position only while every earlier
    # token was valid.
    is_valid = draft_tokens != -1  # [batch, k]
    cum_valid = is_valid.int().cumsum(dim=1)  # [batch, k]
    positions = torch.arange(1, self.k + 1, device=device).unsqueeze(0)
    num_valid_draft_tokens = (cum_valid == positions).int().sum(dim=1)

    return draft_tokens, num_valid_draft_tokens

load_model

load_model(*args, **kwargs)

No model to load for N-gram proposer.

Source code in vllm/v1/spec_decode/ngram_proposer_gpu.py
def load_model(self, *args, **kwargs):
    """Accept and ignore any arguments: the N-gram proposer has no weights."""
    return None

NgramProposerGPU

Source code in vllm/v1/spec_decode/ngram_proposer_gpu.py
class NgramProposerGPU:
    """
    Host-side driver around `NgramGPUKernel`.

    Builds a dedicated `VllmConfig` for the kernel, runs it on dummy inputs
    at construction time, and exposes helpers that write newly sampled
    tokens into the device-side token buffer and invoke the kernel to
    obtain draft tokens.
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
        """
        Create the kernel with its own compilation settings and run it once.

        Args:
            vllm_config: Source configuration; its model, speculative and
                scheduler configs are reused for the kernel's config.
            device: Device the kernel is moved to and inputs must live on.
            runner: Not used by this constructor.
        """
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.prompt_lookup_min is not None
        assert vllm_config.speculative_config.prompt_lookup_max is not None

        # Compilation settings used only for the n-gram kernel; CUDA graph
        # capture is turned off via CUDAGraphMode.NONE.
        compilation_config = CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["none"],
            splitting_ops=[],
            compile_sizes=[],
            inductor_compile_config={
                "enable_auto_functionalized_v2": False,
                "max_autotune": True,
                "aggressive_fusion": True,
                "triton.autotune_pointwise": True,
                "coordinate_descent_tuning": True,
                "use_mixed_mm": False,
            },
            cudagraph_mode=CUDAGraphMode.NONE,
        )
        model_config = vllm_config.model_config
        speculative_config = vllm_config.speculative_config
        scheduler_config = vllm_config.scheduler_config

        # Fresh config: same model/speculative/scheduler settings as the
        # caller's, combined with the kernel-specific compilation config.
        self.vllm_config = VllmConfig(
            compilation_config=compilation_config,
            model_config=model_config,
            speculative_config=speculative_config,
            scheduler_config=scheduler_config,
        )

        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        self.k = vllm_config.speculative_config.num_speculative_tokens
        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.device = device

        self.kernel = NgramGPUKernel(
            vllm_config=self.vllm_config, prefix="ngram_gpu_kernel", device=device
        )
        self.kernel.to(device)
        self.kernel.eval()

        # Runs the kernel on dummy inputs right away (presumably to trigger
        # compilation before the first real request; TODO confirm).
        self._dummy_run()

    def _dummy_run(self):
        """Run the kernel three times on dummy inputs of maximum batch size
        (`max_num_seqs`) and maximum sequence length (`max_model_len`)."""
        token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data(
            batch_size=self.max_num_seqs,
            max_seq_len=self.max_model_len,
            pattern_len=self.k,
            device=self.device,
        )

        # Same mask expression as in propose().
        combined_mask = sampled_flags & valid_mask & (num_tokens >= self.min_n)

        for _ in range(3):
            with set_forward_context(None, self.vllm_config):
                _, _ = self.kernel(num_tokens, token_ids, combined_mask)

    def _generate_dummy_data(
        self,
        batch_size: int,
        max_seq_len: int,
        pattern_len: int,
        device: torch.device | str = "cuda",
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generate placeholder inputs for a dummy kernel run.

        The token IDs are all zeros (no patterns are injected); only the
        per-sequence lengths are random.

        Args:
            batch_size: Number of sequences in the batch
            max_seq_len: Maximum sequence length (width of `token_ids`)
            pattern_len: Inclusive lower bound for the random sequence
                lengths; must be smaller than `max_seq_len`
            device: Device to place tensors on

        Returns:
            token_ids: [batch_size, max_seq_len] int32 tensor of zeros
            num_tokens: [batch_size] int32 tensor with values drawn from
                [pattern_len, max_seq_len)
            sampled_flags: [batch_size] bool tensor, all True
            valid_mask: [batch_size] bool tensor, all True
        """
        token_ids = torch.zeros(
            batch_size,
            max_seq_len,
            dtype=torch.int32,
            device=device,
        )

        # randint's upper bound is exclusive.
        num_tokens = torch.randint(
            pattern_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device
        )

        sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device)
        valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device)

        return token_ids, num_tokens, sampled_flags, valid_mask

    def propose(
        self,
        num_tokens_no_spec: torch.Tensor,  # [batch_size]
        token_ids_gpu: torch.Tensor,  # [batch_size, max_len]
        valid_sampled_token_ids_gpu: torch.Tensor,  # [batch_size, num_spec_tokens + 1]
        valid_sampled_tokens_count: torch.Tensor,  # [batch_size]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Propose draft tokens using GPU-accelerated n-gram matching.

        Scatter sampled tokens into `token_ids_gpu`, compute temporary
        updated lengths, then run the kernel.

        Args:
            num_tokens_no_spec: Number of tokens per sequence (read-only)
            token_ids_gpu: Token IDs tensor (modified in-place with new tokens)
            valid_sampled_token_ids_gpu: Newly sampled tokens to scatter
            valid_sampled_tokens_count: Count of valid tokens per sequence

        Returns:
            draft_tokens: Proposed draft token IDs [batch_size, k]
            num_valid_draft_tokens: Count of leading valid draft tokens
                per request [batch_size]
        """
        assert token_ids_gpu.device == self.device
        assert num_tokens_no_spec.device == self.device

        batch_size = num_tokens_no_spec.shape[0]
        max_seq_len = token_ids_gpu.shape[1]
        max_new_tokens = valid_sampled_token_ids_gpu.shape[1]  # num_spec_tokens + 1

        # Scatter newly sampled tokens into token_ids_gpu.
        # Slot j of a row targets column num_tokens_no_spec + j.
        offsets = torch.arange(max_new_tokens, device=self.device)
        write_positions = num_tokens_no_spec.unsqueeze(1) + offsets.unsqueeze(0)
        valid_write_mask = offsets.unsqueeze(0) < valid_sampled_tokens_count.unsqueeze(
            1
        )
        in_bounds = write_positions < max_seq_len
        # A slot is written only if it is within the row's valid count, holds
        # a real token (not the -1 sentinel) and lands inside the buffer.
        scatter_mask = (
            valid_write_mask & (valid_sampled_token_ids_gpu != -1) & in_bounds
        )

        # Every slot takes part in the scatter; slots that must not change
        # the buffer write back the value that is already stored there, so
        # the gather has to happen before the scatter.
        # NOTE(review): out-of-bounds positions are clamped to the last
        # column, so one row can carry duplicate indices into scatter_, whose
        # result for non-unique indices is documented as nondeterministic.
        # Confirm rows never reach the end of the buffer with live writes.
        write_positions_long = write_positions.clamp(max=max_seq_len - 1).long()
        existing_values = token_ids_gpu.gather(1, write_positions_long)

        tokens_cast = valid_sampled_token_ids_gpu.to(token_ids_gpu.dtype)
        tokens_to_scatter = torch.where(
            scatter_mask,
            tokens_cast,
            existing_values,
        )
        token_ids_gpu.scatter_(1, write_positions_long, tokens_to_scatter)

        # Lengths including the tokens just written; num_tokens_no_spec
        # itself is left untouched.
        num_tokens_tmp = num_tokens_no_spec + valid_sampled_tokens_count

        # Compute validity masks.
        sampled_flags = valid_sampled_tokens_count > 0
        # All-True mask: it does not filter any row here.
        valid_mask = torch.ones(batch_size, dtype=torch.bool, device=self.device)

        with set_forward_context(None, self.vllm_config):
            # Propose only for rows that received at least one token and are
            # at least min_n tokens long.
            combined_mask = sampled_flags & valid_mask & (num_tokens_tmp >= self.min_n)

            with record_function_or_nullcontext("ngram_proposer_gpu: kernel"):
                draft_tokens, num_valid_draft_tokens = self.kernel(
                    num_tokens_tmp,
                    token_ids_gpu,
                    combined_mask,
                )

            return draft_tokens, num_valid_draft_tokens

    def update_token_ids_ngram(
        self,
        sampled_token_ids: torch.Tensor | list[list[int]],
        gpu_input_batch: InputBatch,
        token_ids_gpu: torch.Tensor,
        num_tokens_no_spec: torch.Tensor,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepare speculative decoding inputs on device:
        compute next token ids and valid counts, honoring discarded requests
        and rejected tokens, without CPU-GPU sync.

        Args:
            sampled_token_ids: Sampled tokens per request, either a tensor
                or a ragged list of lists (padded with -1 here).
            gpu_input_batch: Supplies `num_reqs` and `vocab_size`.
            token_ids_gpu: Token buffer; only read, to fetch the fallback token.
            num_tokens_no_spec: Per-request token counts; the first
                `num_reqs` entries are used.
            discard_request_mask: Per-request flags; True rows have all
                their sampled tokens replaced by -1.

        Returns:
            next_token_ids: Per request, the sampled token at index
                `valid_count - 1`, or the last token already in
                `token_ids_gpu` when the request has no valid sampled token.
            valid_sampled_tokens_count: Number of valid sampled tokens per
                request.
            valid_sampled_token_ids_gpu: Copy of the sampled tokens with
                discarded requests set to -1.
        """
        num_reqs = gpu_input_batch.num_reqs

        if isinstance(sampled_token_ids, list):
            # When disable_padded_drafter_batch=True, sampled_token_ids is
            # an irregular list[list[int]] where sublists may have different
            # lengths (including empty lists for discarded requests).
            # Pad all sublists to the same length with -1 before converting
            # to tensor.
            max_len = max(
                (len(sublist) for sublist in sampled_token_ids),
                default=0,
            )
            # Ensure at least length 1 for tensor creation
            max_len = max(max_len, 1)
            padded_list = [
                sublist + [-1] * (max_len - len(sublist))
                for sublist in sampled_token_ids
            ]
            sampled_token_ids = torch.tensor(
                padded_list, dtype=torch.int32, device=self.device
            )
        assert isinstance(sampled_token_ids, torch.Tensor), (
            "sampled_token_ids should be a torch.Tensor for ngram_gpu"
        )

        # Backup last valid token before speculative tokens.
        # The index is clamped to 0 so empty rows still gather in bounds.
        backup_indices = (num_tokens_no_spec[:num_reqs] - 1).clamp(min=0).long()
        backup_next_token_ids = torch.gather(
            token_ids_gpu[:num_reqs], dim=1, index=backup_indices.unsqueeze(1)
        ).squeeze(1)

        # Work on a copy so the caller's tensor is not modified.
        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
        # Invalidate sampled tokens for discarded requests.
        discard_mask_expanded = discard_request_mask[:num_reqs].unsqueeze(1)
        valid_sampled_token_ids_gpu.masked_fill_(discard_mask_expanded, -1)

        # Mask valid tokens within each request.
        valid_mask = (valid_sampled_token_ids_gpu != -1) & (
            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
        )

        # Count valid tokens per request.
        valid_sampled_tokens_count = valid_mask.sum(dim=1)

        # Rightmost valid index per row.
        # NOTE(review): count - 1 is the rightmost valid index only if the
        # valid tokens form a contiguous prefix of each row; confirm with
        # the producer of sampled_token_ids.
        last_valid_indices = valid_sampled_tokens_count - 1
        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)

        # Last valid token from each row; undefined if none.
        selected_tokens = torch.gather(
            valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
        ).squeeze(1)

        # Use last token if valid; otherwise fallback to backup.
        next_token_ids = torch.where(
            last_valid_indices != -1,
            selected_tokens,
            backup_next_token_ids,
        )

        return next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu

    def load_model(self, *args, **kwargs):
        """Forward to the kernel's `load_model`."""
        self.kernel.load_model(*args, **kwargs)

_generate_dummy_data

_generate_dummy_data(
    batch_size: int,
    max_seq_len: int,
    pattern_len: int,
    device: str = "cuda",
) -> tuple[Tensor, Tensor, Tensor, Tensor]

Generate placeholder inputs for a dummy kernel run: all-zero token IDs (no n-gram patterns are injected) with random per-sequence lengths.

Parameters:

Name Type Description Default
batch_size int

Number of sequences in the batch

required
max_seq_len int

Maximum sequence length

required
pattern_len int

Inclusive lower bound for the random sequence lengths (must be smaller than max_seq_len)

required
device str

Device to place tensors on

'cuda'

Returns:

Name Type Description
token_ids Tensor

[batch_size, max_seq_len] tensor

num_tokens Tensor

[batch_size] tensor

sampled_flags Tensor

[batch_size] bool tensor

valid_mask Tensor

[batch_size] bool tensor

Source code in vllm/v1/spec_decode/ngram_proposer_gpu.py
def _generate_dummy_data(
    self,
    batch_size: int,
    max_seq_len: int,
    pattern_len: int,
    device: str = "cuda",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Generate random test data with n-gram repetitions.

    Args:
        batch_size: Number of sequences in the batch
        max_seq_len: Maximum sequence length
        pattern_len: Length of patterns to inject for matching
        device: Device to place tensors on

    Returns:
        token_ids: [batch_size, max_seq_len] tensor
        num_tokens: [batch_size] tensor
        sampled_flags: [batch_size] bool tensor
        valid_mask: [batch_size] bool tensor
    """
    token_ids = torch.zeros(
        batch_size,
        max_seq_len,
        dtype=torch.int32,
        device=device,
    )

    num_tokens = torch.randint(
        pattern_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device
    )

    sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device)
    valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device)

    return token_ids, num_tokens, sampled_flags, valid_mask

propose

propose(
    num_tokens_no_spec: Tensor,
    token_ids_gpu: Tensor,
    valid_sampled_token_ids_gpu: Tensor,
    valid_sampled_tokens_count: Tensor,
) -> tuple[Tensor, Tensor]

Propose draft tokens using GPU-accelerated n-gram matching.

Scatter sampled tokens into token_ids_gpu, compute temporary updated lengths, then run the kernel.

Parameters:

Name Type Description Default
num_tokens_no_spec Tensor

Number of tokens per sequence (read-only)

required
token_ids_gpu Tensor

Token IDs tensor (modified in-place with new tokens)

required
valid_sampled_token_ids_gpu Tensor

Newly sampled tokens to scatter

required
valid_sampled_tokens_count Tensor

Count of valid tokens per sequence

required

Returns:

Name Type Description
draft_tokens Tensor

Proposed draft token IDs [batch_size, k]

num_valid_draft_tokens Tensor

Count of leading valid draft tokens per request [batch_size]

Source code in vllm/v1/spec_decode/ngram_proposer_gpu.py
def propose(
    self,
    num_tokens_no_spec: torch.Tensor,  # [batch_size]
    token_ids_gpu: torch.Tensor,  # [batch_size, max_len]
    valid_sampled_token_ids_gpu: torch.Tensor,  # [batch_size, num_spec_tokens + 1]
    valid_sampled_tokens_count: torch.Tensor,  # [batch_size]
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Write the newly sampled tokens into the token buffer and draft the next
    tokens with the n-gram kernel.

    The sampled tokens are stored in `token_ids_gpu` right behind each
    sequence's current end, the lengths are extended by the number of valid
    sampled tokens (in a temporary tensor), and the kernel is run on the
    result.

    Args:
        num_tokens_no_spec: Number of tokens per sequence (read-only)
        token_ids_gpu: Token IDs tensor (modified in-place with new tokens)
        valid_sampled_token_ids_gpu: Newly sampled tokens to scatter
        valid_sampled_tokens_count: Count of valid tokens per sequence

    Returns:
        draft_tokens: Proposed draft token IDs [batch_size, k]
        num_valid_draft_tokens: Count of leading valid draft tokens
            per request [batch_size]
    """
    assert token_ids_gpu.device == self.device
    assert num_tokens_no_spec.device == self.device

    num_rows = num_tokens_no_spec.shape[0]
    buffer_width = token_ids_gpu.shape[1]
    # One slot per sampled token (num_spec_tokens + 1 of them), kept as a
    # [1, slots] row so it broadcasts against per-row columns.
    slots = torch.arange(
        valid_sampled_token_ids_gpu.shape[1], device=self.device
    ).unsqueeze(0)

    # Slot j of a row targets column num_tokens_no_spec + j.
    target_cols = num_tokens_no_spec.unsqueeze(1) + slots

    # A slot is really written only when it is within the row's valid count,
    # carries a real token (not the -1 sentinel) and fits in the buffer.
    within_count = slots < valid_sampled_tokens_count.unsqueeze(1)
    is_real_token = valid_sampled_token_ids_gpu != -1
    fits_buffer = target_cols < buffer_width
    writable = within_count & is_real_token & fits_buffer

    # All slots take part in the scatter; the non-writable ones write back
    # what is already stored, so the current values must be read first.
    safe_cols = target_cols.clamp(max=buffer_width - 1).long()
    current = token_ids_gpu.gather(1, safe_cols)
    incoming = valid_sampled_token_ids_gpu.to(token_ids_gpu.dtype)
    token_ids_gpu.scatter_(1, safe_cols, torch.where(writable, incoming, current))

    # Lengths including the freshly written tokens; the input is not modified.
    num_tokens_tmp = num_tokens_no_spec + valid_sampled_tokens_count

    received_tokens = valid_sampled_tokens_count > 0
    every_row = torch.ones(num_rows, dtype=torch.bool, device=self.device)

    with set_forward_context(None, self.vllm_config):
        # Draft only for rows that got a token and are at least min_n long.
        long_enough = num_tokens_tmp >= self.min_n
        combined_mask = received_tokens & every_row & long_enough

        with record_function_or_nullcontext("ngram_proposer_gpu: kernel"):
            draft_tokens, num_valid_draft_tokens = self.kernel(
                num_tokens_tmp,
                token_ids_gpu,
                combined_mask,
            )

        return draft_tokens, num_valid_draft_tokens

update_token_ids_ngram

update_token_ids_ngram(
    sampled_token_ids: Tensor | list[list[int]],
    gpu_input_batch: InputBatch,
    token_ids_gpu: Tensor,
    num_tokens_no_spec: Tensor,
    discard_request_mask: Tensor,
) -> tuple[Tensor, Tensor, Tensor]

Prepare speculative decoding inputs on device: compute next token ids and valid counts, honoring discarded requests and rejected tokens, without CPU-GPU sync.

Source code in vllm/v1/spec_decode/ngram_proposer_gpu.py
def update_token_ids_ngram(
    self,
    sampled_token_ids: torch.Tensor | list[list[int]],
    gpu_input_batch: InputBatch,
    token_ids_gpu: torch.Tensor,
    num_tokens_no_spec: torch.Tensor,
    discard_request_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Derive the next token id and the valid sampled-token count per request.

    Args:
        sampled_token_ids: Sampled tokens, either a 2-D tensor or a ragged
            list of per-request lists. Ragged lists are right-padded with
            -1 and converted to an int32 tensor on ``self.device``.
        gpu_input_batch: Batch object; only ``num_reqs`` and ``vocab_size``
            are read here.
        token_ids_gpu: Per-request token history, indexed [request, position].
        num_tokens_no_spec: Per-request token counts used to locate the
            fallback token in ``token_ids_gpu``.
        discard_request_mask: Boolean mask; rows set to True have all their
            sampled tokens replaced by -1.

    Returns:
        ``(next_token_ids, valid_sampled_tokens_count,
        valid_sampled_token_ids_gpu)`` where the last element is a copy of
        the sampled tokens with discarded rows filled with -1. The input
        tensor itself is not modified.
    """
    n_reqs = gpu_input_batch.num_reqs

    if isinstance(sampled_token_ids, list):
        # Ragged input (e.g. empty sublists for discarded requests): pad
        # every row with -1 up to the longest row, at least one column wide
        # so that a 2-D tensor can always be built.
        width = max(max(map(len, sampled_token_ids), default=0), 1)
        rows = [row + [-1] * (width - len(row)) for row in sampled_token_ids]
        sampled_token_ids = torch.tensor(
            rows, dtype=torch.int32, device=self.device
        )
    assert isinstance(sampled_token_ids, torch.Tensor), (
        "sampled_token_ids should be a torch.Tensor for ngram_gpu"
    )

    # Work on a copy and blank out every token of a discarded request.
    valid_sampled_token_ids_gpu = sampled_token_ids.clone()
    valid_sampled_token_ids_gpu.masked_fill_(
        discard_request_mask[:n_reqs].unsqueeze(1), -1
    )

    # A token counts as valid when it is not the -1 sentinel and lies
    # below the vocabulary size.
    is_valid = (valid_sampled_token_ids_gpu != -1) & (
        valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
    )
    valid_sampled_tokens_count = is_valid.sum(dim=1)

    # Column of the last valid token in each row (-1 when the row has none).
    # The index is clamped so the gather stays in range; rows without a
    # valid token are replaced by the fallback below.
    last_col = valid_sampled_tokens_count - 1
    last_sampled = valid_sampled_token_ids_gpu.gather(
        1, last_col.clamp(min=0).unsqueeze(1)
    ).squeeze(1)

    # Fallback: the token stored at position num_tokens_no_spec - 1 of the
    # request's history (position 0 when the count is 0).
    fallback_col = (num_tokens_no_spec[:n_reqs] - 1).clamp(min=0).long()
    fallback_tokens = (
        token_ids_gpu[:n_reqs].gather(1, fallback_col.unsqueeze(1)).squeeze(1)
    )

    next_token_ids = torch.where(
        valid_sampled_tokens_count > 0,
        last_sampled,
        fallback_tokens,
    )

    return next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu

_sync_num_tokens

_sync_num_tokens(
    input_batch: InputBatch,
    num_tokens_no_spec_gpu: Tensor,
    active_idx_cpu: Tensor,
    active_idx_gpu: Tensor,
    n_active: int,
    device: device,
    _pinned_val_buf: Tensor,
) -> None

Batch-sync GPU sequence lengths from CPU source of truth.

Inputs

- input_batch: Batch container with CPU length tensor.
- num_tokens_no_spec_gpu: Destination GPU length tensor.
- active_idx_cpu: Active request indices on CPU.
- active_idx_gpu: Active request indices on GPU.
- n_active: Number of active requests.
- device: Target CUDA device.
- _pinned_val_buf: Resident pinned int32 staging buffer.

Outputs: None (updates num_tokens_no_spec_gpu in-place).

Source code in vllm/v1/spec_decode/ngram_proposer_gpu.py
def _sync_num_tokens(
    input_batch: InputBatch,
    num_tokens_no_spec_gpu: torch.Tensor,
    active_idx_cpu: torch.Tensor,
    active_idx_gpu: torch.Tensor,
    n_active: int,
    device: torch.device,
    _pinned_val_buf: torch.Tensor,
) -> None:
    """Overwrite the GPU lengths of the active requests with the CPU values.

    Inputs:
        input_batch: Batch container; ``num_tokens_no_spec_cpu_tensor`` is
            read as the source of truth.
        num_tokens_no_spec_gpu: Destination length tensor, updated in place
            at the active indices only.
        active_idx_cpu: Active request indices used to gather from the CPU
            source.
        active_idx_gpu: The same indices, used to scatter into the
            destination.
        n_active: Number of active requests (length of the staging slice).
        device: Device the staged values are moved to.
        _pinned_val_buf: Resident pinned int32 staging buffer; its first
            ``n_active`` entries are overwritten.
    Outputs:
        None (updates num_tokens_no_spec_gpu in-place).
    """
    # Gather the lengths of the active requests on the host.
    host_lengths = torch.index_select(
        input_batch.num_tokens_no_spec_cpu_tensor, 0, active_idx_cpu
    )

    # Stage them in the resident buffer instead of a fresh allocation.
    staging = _pinned_val_buf[:n_active]
    staging.copy_(host_lengths)

    # Move the staged values to the target device and write them into the
    # destination rows in a single indexed copy.
    device_lengths = staging.to(device=device, non_blocking=True)
    num_tokens_no_spec_gpu.index_copy_(0, active_idx_gpu, device_lengths)

copy_num_valid_draft_tokens

copy_num_valid_draft_tokens(
    num_valid_draft_tokens_cpu: Tensor,
    num_valid_draft_tokens_copy_stream: Stream,
    num_valid_draft_tokens_event: Event,
    num_valid_draft_tokens: Tensor | None,
    batch_size: int,
) -> None

Async D2H copy of per-request valid draft counts.

Source code in vllm/v1/spec_decode/ngram_proposer_gpu.py
def copy_num_valid_draft_tokens(
    num_valid_draft_tokens_cpu: torch.Tensor,
    num_valid_draft_tokens_copy_stream: torch.cuda.Stream,
    num_valid_draft_tokens_event: torch.cuda.Event,
    num_valid_draft_tokens: torch.Tensor | None,
    batch_size: int,
) -> None:
    """
    Copy the per-request valid draft counts to the CPU buffer on a side stream.

    Does nothing when ``num_valid_draft_tokens`` is None or when there are
    no rows to copy. Otherwise the first
    ``min(batch_size, len(num_valid_draft_tokens))`` entries are copied into
    ``num_valid_draft_tokens_cpu`` with ``non_blocking=True`` and
    ``num_valid_draft_tokens_event`` is recorded after the copy is issued,
    so a consumer can synchronize on the event before reading the buffer.
    """
    if num_valid_draft_tokens is None:
        return

    n_rows = min(batch_size, num_valid_draft_tokens.shape[0])
    if n_rows <= 0:
        return

    # Stream that was current when this function was called; the copy
    # stream waits on it before the copy is issued.
    producer_stream = torch.cuda.current_stream()
    copy_stream = num_valid_draft_tokens_copy_stream

    with torch.cuda.stream(copy_stream):
        copy_stream.wait_stream(producer_stream)
        dst = num_valid_draft_tokens_cpu[:n_rows]
        dst.copy_(num_valid_draft_tokens[:n_rows], non_blocking=True)
        # Recorded while the copy stream is current.
        num_valid_draft_tokens_event.record()

update_ngram_gpu_tensors_incremental

update_ngram_gpu_tensors_incremental(
    input_batch: InputBatch,
    token_ids_gpu_tensor: Tensor,
    num_tokens_no_spec_gpu: Tensor,
    new_reqs: list[CachedRequestState],
    device: device,
    _pinned_idx_buf: Tensor,
    _pinned_val_buf: Tensor,
) -> None

Incrementally update token_ids_gpu_tensor and num_tokens_no_spec_gpu for ngram GPU proposer.

Source code in vllm/v1/spec_decode/ngram_proposer_gpu.py
def _copy_request_tokens_to_gpu(
    input_batch: InputBatch,
    token_ids_gpu_tensor: torch.Tensor,
    req_idx: int,
) -> None:
    """Copy one request's non-speculative tokens from the CPU batch to GPU."""
    num_tokens = input_batch.num_tokens_no_spec[req_idx]
    if num_tokens > 0:
        token_ids_gpu_tensor[req_idx, :num_tokens].copy_(
            input_batch.token_ids_cpu_tensor[req_idx, :num_tokens],
            non_blocking=True,
        )


def update_ngram_gpu_tensors_incremental(
    input_batch: InputBatch,
    token_ids_gpu_tensor: torch.Tensor,
    num_tokens_no_spec_gpu: torch.Tensor,
    new_reqs: list[CachedRequestState],
    device: torch.device,
    _pinned_idx_buf: torch.Tensor,
    _pinned_val_buf: torch.Tensor,
) -> None:
    """Incrementally update token_ids_gpu_tensor and num_tokens_no_spec_gpu
    for ngram GPU proposer.

    Three steps are applied, all in place:

    1. Requests that kept their id but moved to a different batch index have
       their GPU row (tokens and length) moved from the old index to the
       new one.
    2. Requests listed in ``new_reqs`` get their token row copied from the
       CPU batch.
    3. Sequence lengths of all active requests are re-synced from the CPU
       batch.

    On the first call (``input_batch.prev_req_id_to_index is None``) every
    active request is copied in full and step 1 is skipped. Nothing happens
    when the batch has no active requests.

    Args:
        input_batch: Batch container providing the request-id/index maps
            and the CPU token/length data.
        token_ids_gpu_tensor: Token history on the target device, indexed
            [request, position].
        num_tokens_no_spec_gpu: Per-request lengths on the target device.
        new_reqs: Requests whose rows must be copied in full.
        device: Target device.
        _pinned_idx_buf: Resident staging buffer for the active indices.
        _pinned_val_buf: Resident staging buffer for the length sync.
    """
    prev_req_id_to_index = input_batch.prev_req_id_to_index
    curr_req_id_to_index = input_batch.req_id_to_index

    if not curr_req_id_to_index:
        return

    active_indices = list(curr_req_id_to_index.values())
    n_active = len(active_indices)

    # Use resident pinned buffers to avoid per-call allocation.
    active_idx_cpu = _pinned_idx_buf[:n_active]
    active_idx_cpu.copy_(torch.as_tensor(active_indices, dtype=torch.long))

    active_idx_gpu = active_idx_cpu.to(device=device, non_blocking=True)

    # First run, no previous state: upload every active request.
    if prev_req_id_to_index is None:
        for idx in active_indices:
            _copy_request_tokens_to_gpu(input_batch, token_ids_gpu_tensor, idx)

        _sync_num_tokens(
            input_batch,
            num_tokens_no_spec_gpu,
            active_idx_cpu,
            active_idx_gpu,
            n_active,
            device,
            _pinned_val_buf,
        )
        return

    # Only needed past the first-run path, so it is built here.
    new_req_ids = {req.req_id for req in new_reqs}

    # Detect index changes for reorder. New requests are excluded because
    # their rows are rewritten from the CPU batch below.
    reorder_src: list[int] = []
    reorder_dst: list[int] = []

    for req_id, curr_idx in curr_req_id_to_index.items():
        if req_id in new_req_ids:
            continue
        prev_idx = prev_req_id_to_index.get(req_id)
        if prev_idx is not None and prev_idx != curr_idx:
            reorder_src.append(prev_idx)
            reorder_dst.append(curr_idx)

    if reorder_src:
        src_tensor = torch.tensor(reorder_src, dtype=torch.long, device=device)
        dst_tensor = torch.tensor(reorder_dst, dtype=torch.long, device=device)

        # Indexing with an index tensor already returns a fresh copy, so the
        # source rows are fully read before any destination row is written
        # (safe for swaps) and no additional clone is required.
        moved_token_ids = token_ids_gpu_tensor[src_tensor]
        moved_num_tokens = num_tokens_no_spec_gpu[src_tensor]

        token_ids_gpu_tensor[dst_tensor] = moved_token_ids
        num_tokens_no_spec_gpu[dst_tensor] = moved_num_tokens

    # Full copy for new/resumed requests.
    for req_state in new_reqs:
        new_req_idx = curr_req_id_to_index.get(req_state.req_id)
        if new_req_idx is None:
            continue
        _copy_request_tokens_to_gpu(input_batch, token_ids_gpu_tensor, new_req_idx)

    # Always batch-sync sequence lengths from CPU for ALL active requests.
    _sync_num_tokens(
        input_batch,
        num_tokens_no_spec_gpu,
        active_idx_cpu,
        active_idx_gpu,
        n_active,
        device,
        _pinned_val_buf,
    )

update_scheduler_for_invalid_drafts

update_scheduler_for_invalid_drafts(
    num_valid_draft_tokens_event: Event,
    num_valid_draft_tokens_cpu: Tensor,
    scheduler_output: SchedulerOutput,
    req_id_to_index: dict[str, int],
) -> None

Trim invalid speculative slots using per-request valid draft counts.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_valid_draft_tokens_event | Event | Event for async D2H completion. | required |
| num_valid_draft_tokens_cpu | Tensor | CPU buffer of valid draft counts. | required |
| scheduler_output | SchedulerOutput | Scheduler metadata to update in-place. | required |
| req_id_to_index | dict[str, int] | Request-id to batch-index mapping. | required |
Source code in vllm/v1/spec_decode/ngram_proposer_gpu.py
def update_scheduler_for_invalid_drafts(
    num_valid_draft_tokens_event: torch.cuda.Event,
    num_valid_draft_tokens_cpu: torch.Tensor,
    scheduler_output: "SchedulerOutput",
    req_id_to_index: dict[str, int],
) -> None:
    """Shrink each request's scheduled draft tokens to its valid draft count.

    For every cached request that has both a batch index and scheduled
    speculative tokens, the draft list is truncated to the valid count
    (clamped to ``[0, number scheduled]``). The per-request and total
    scheduled token counters are reduced by the number of dropped tokens,
    and the draft entry is removed entirely when nothing remains.

    Args:
        num_valid_draft_tokens_event: Event for async D2H completion;
            synchronized before the CPU buffer is read.
        num_valid_draft_tokens_cpu: CPU buffer of valid draft counts,
            indexed by batch index.
        scheduler_output: Scheduler metadata to update in-place.
        req_id_to_index: Request-id to batch-index mapping.
    """
    # The counts are only safe to read once the pending copy has finished.
    num_valid_draft_tokens_event.synchronize()

    spec_tokens_by_req = scheduler_output.scheduled_spec_decode_tokens

    for req_id in scheduler_output.scheduled_cached_reqs.req_ids:
        batch_idx = req_id_to_index.get(req_id)
        drafts = spec_tokens_by_req.get(req_id)
        # Skip requests that are not in the batch or have no drafts.
        if batch_idx is None or drafts is None:
            continue

        n_scheduled = len(drafts)
        n_kept = int(num_valid_draft_tokens_cpu[batch_idx].item())
        n_kept = min(max(n_kept, 0), n_scheduled)

        n_dropped = n_scheduled - n_kept
        scheduler_output.total_num_scheduled_tokens -= n_dropped
        scheduler_output.num_scheduled_tokens[req_id] -= n_dropped

        if n_kept:
            spec_tokens_by_req[req_id] = drafts[:n_kept]
        else:
            spec_tokens_by_req.pop(req_id, None)