Skip to content

vllm.lora.punica_wrapper.utils

compute_meta

compute_meta(
    token_lora_tensor: Tensor,
) -> tuple[Tensor, Tensor, Tensor, int, int, int, bool]

Get the information required for the sgmv kernel. With the features: 1. If consecutive requests in the batch use the same LoRA, this function will combine them into a single request, improving sgmv kernel inference performance. 2. At the beginning of each prefill stage inference, recalculations are needed based on the input, but only once.

Source code in vllm/lora/punica_wrapper/utils.py
def compute_meta(
    token_lora_tensor: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
    """
    Build the per-segment metadata required by the sgmv kernel.

    The per-token LoRA indices are collapsed into runs of consecutive equal
    values, which gives the following properties:
    1. Consecutive requests in the batch that use the same LoRA are merged
    into a single segment, improving sgmv kernel inference performance.
    2. At the beginning of each prefill stage the metadata has to be
    recomputed from the input, but only once.

    Returns:
        A tuple ``(b_seq_start_tensor, seq_length_tensor,
        lora_indices_tensor, batch_size, max_length, token_nums, no_lora)``:
        the start offset and length of every segment, the LoRA index of every
        segment, the number of segments, the longest segment length, the
        total number of tokens and whether the whole batch runs without LoRA.
    """
    segment_lora_ids, segment_lengths = torch.unique_consecutive(
        token_lora_tensor, return_counts=True)

    # Exclusive prefix sum of the lengths: the first segment starts at 0 and
    # every later one starts where the previous segment ended.
    segment_starts = torch.zeros_like(segment_lengths)
    segment_starts[1:] = torch.cumsum(segment_lengths, dim=0)[:-1]

    longest_segment = segment_lengths.max().item()
    total_tokens = segment_lengths.sum().item()
    num_segments = segment_lora_ids.size(0)

    # -1 means no lora should be applied. `no_lora` tells the caller whether
    # the current step requires LoRA at all; when it does not, the prefill
    # stage can skip launching the triton kernel, which improves performance.
    no_lora = num_segments == 1 and bool(segment_lora_ids[0] == -1)

    return (segment_starts, segment_lengths, segment_lora_ids, num_segments,
            longest_segment, total_tokens, no_lora)

convert_mapping

convert_mapping(
    mapping: LoRAMapping,
    lora_index_to_id: list[Optional[int]],
    max_loras: int,
    vocab_size: int,
    extra_vocab_size: int,
    device: device,
    long_lora_context: Optional[
        LongContextLoRAContext
    ] = None,
) -> tuple[
    Tensor,
    Tensor,
    Tensor,
    Tensor,
    Optional[Tensor],
    list[int],
]

Converts LoRAMapping to index tensors.

Parameters:

Name Type Description Default
mapping LoRAMapping

LoRAMapping mapping rows in a batch to LoRA ids.

required
lora_index_to_id list[Optional[int]]

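List indexed by LoRA index (slot) whose entries are the LoRA ids loaded in each slot; the position of an id in this list is its LoRA index.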

required
max_loras int

Maximum number of LoRAs.

required
vocab_size int

Model vocab size.

required
extra_vocab_size int

Extra vocab size each LoRA can have.

required
long_lora_context Optional[LongContextLoRAContext]

Passed if there are long context lora in a batch.

None

Returns:

Type Description
tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor], list[int]]

A tuple of tensors: base_indices: Tensor of shape [batch_size] mapping batch rows to LoRA indices. sampler_indices: Tensor of shape [batch_size] mapping requests to LoRA indices for sampler. For generation, this will be the same as base_indices. For prefill, this will map requests to LoRA indices. sampler_indices_padded: Tensor of shape [batch_size] mapping requests to LoRA indices for sampler with padding. Derived from sampler_indices with -1 replaced by max_loras - 1; each entry is then converted to its position plus that value times the number of entries. embeddings_indices: Tensor of shape [2, batch_size] mapping requests to embedding indices. First row is for embeddings added by the LoRAs, second row is for the LoRA.lora_a embeddings. long_lora_indices: Tensor of shape [batch_size] mapping requests to RoPE offsets and rot dims for long LoRAs. None if long context lora doesn't exist. indices_len: List of lengths of the above tensors. It contains (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, long_lora_indices); the last entry is None if long context lora doesn't exist.

Source code in vllm/lora/punica_wrapper/utils.py
def convert_mapping(
    mapping: "LoRAMapping",
    lora_index_to_id: list[Optional[int]],
    max_loras: int,
    vocab_size: int,
    extra_vocab_size: int,
    device: torch.device,
    long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           Optional[torch.Tensor], list[Optional[int]]]:
    """Converts LoRAMapping to index tensors.

    Args:
        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
        lora_index_to_id: List indexed by LoRA index (slot) whose entries are
            the LoRA ids; the position of an id in this list is its LoRA
            index.
        max_loras: Maximum number of LoRAs.
        vocab_size: Model vocab size.
        extra_vocab_size: Extra vocab size each LoRA can have.
        device: Device on which the returned tensors are created.
        long_lora_context: Passed if there are long context lora in a batch.

    Returns:
        A tuple of tensors:
            base_indices: Tensor of shape [batch_size] mapping batch rows to
                LoRA indices (-1 for rows whose LoRA id is not positive).
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for sampler. For generation, this will be the
                same as base_indices. For prefill, this will map requests
                to LoRA indices.
            sampler_indices_padded: Tensor of shape [batch_size] mapping
                requests to LoRA indices for sampler with padding.
                Derived from sampler_indices with -1 replaced by
                max_loras - 1; entry ``i`` then becomes
                ``i + value * len(sampler_indices)``.
            embeddings_indices: Tensor of shape [2, batch_size] mapping
                requests to embedding indices. First row is for embeddings
                added by the LoRAs, second row is for the LoRA.lora_a
                embeddings.
            long_lora_indices: Tensor of shape [batch_size] mapping
                requests to RoPE offsets and rot dims for long LoRAs.
                None if long context lora doesn't exist.
            indices_len: List of lengths of the above tensors. It contains
                (base_indices, sampler_indices, sampler_indices_padded,
                embeddings_indices, long_lora_indices); the last entry is
                None if long context lora doesn't exist.

    Raises:
        ValueError: If a positive LoRA id in ``mapping`` is not present in
            ``lora_index_to_id``.
    """
    # A batch usually has far more tokens than distinct LoRA ids, so every id
    # is resolved with `list.index` only once and then served from the cache.
    index_by_lora_id: dict[int, int] = {}

    def lora_index_of(lora_id: int) -> int:
        """Return the LoRA index of `lora_id`, or -1 for non-positive ids."""
        if lora_id > 0:
            lora_index = index_by_lora_id.get(lora_id)
            if lora_index is None:
                lora_index = lora_index_to_id.index(lora_id)
                index_by_lora_id[lora_id] = lora_index
            return lora_index
        return -1

    token_lora_ids: list[int] = list(mapping.index_mapping)
    prompt_mapping: list[int] = [
        lora_index_of(lora_id) for lora_id in mapping.prompt_mapping
    ]
    lora_indices: list[int] = [
        lora_index_of(lora_id) for lora_id in token_lora_ids
    ]
    # Rows without a LoRA use 0 (not -1) as their embedding index.
    embedding_indices: list[int] = [
        lora_index if lora_id > 0 else 0
        for lora_id, lora_index in zip(token_lora_ids, lora_indices)
    ]

    index_rows: list[list[int]] = [
        token_lora_ids,
        lora_indices,
        embedding_indices,
    ]
    if long_lora_context:
        # Built as a plain list so that all rows reach the device in the
        # single tensor construction below; ids without an offset get 0.
        offsets_by_lora_id = long_lora_context.offsets_by_lora_id
        index_rows.append(
            [offsets_by_lora_id.get(lora_id, 0) for lora_id in token_lora_ids])

    indices = torch.tensor(index_rows, dtype=torch.long, device=device)
    sampler_indices = torch.tensor(prompt_mapping,
                                   dtype=torch.long,
                                   device=device)

    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size),
    ])
    embeddings_indices = torch.where(embeddings_indices == -1, max_loras - 1,
                                     embeddings_indices)
    base_indices = indices[1]

    # torch.where is out-of-place, so sampler_indices itself is left intact.
    sampler_indices_padded = torch.where(sampler_indices == -1, max_loras - 1,
                                         sampler_indices)
    num_sampler_rows = len(sampler_indices_padded)
    sampler_indices_padded = torch.arange(
        0, num_sampler_rows, device=device,
        dtype=torch.long) + (sampler_indices_padded * num_sampler_rows)

    long_lora_indices: Optional[torch.Tensor] = None
    long_lora_indices_len: Optional[int] = None
    if long_lora_context:
        long_lora_indices = indices[3]
        long_lora_indices_len = long_lora_indices.shape[-1]

    # Contain length of indices tensors. Used to index into each tensor.
    # The last entry is None when long context lora doesn't exist.
    indices_len: list[Optional[int]] = [
        base_indices.shape[-1],
        sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1],
        embeddings_indices.shape[-1],
        long_lora_indices_len,
    ]

    return (
        base_indices,
        sampler_indices,
        sampler_indices_padded,
        embeddings_indices,
        long_lora_indices,
        indices_len,
    )