Skip to content

vllm.lora.ops.torch_ops

Modules:

Name Description
lora_ops

__all__ module-attribute

# Public API. The bgmv_* ops take one adapter index per input row; the
# sgmv_* ops take per-sequence indices plus sequence lengths and expand
# them to per-row indices (via repeat_interleave) before delegating to the
# matching bgmv_* op.
__all__ = [
    "bgmv_expand", "bgmv_expand_slice", "bgmv_shrink",
    "sgmv_expand", "sgmv_expand_slice", "sgmv_shrink",
]

bgmv_expand

bgmv_expand(
    inputs: Tensor,
    lora_b_weights: Tensor,
    output_tensor: Tensor,
    lora_indices_tensor: Tensor,
    add_inputs: bool = True,
)
Source code in vllm/lora/ops/torch_ops/lora_ops.py
def bgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                add_inputs: bool = True):
    """Project each input row through its selected LoRA B weights.

    Row ``b`` of ``inputs`` is contracted with
    ``lora_b_weights[lora_indices_tensor[b]]`` over its last axis
    (einsum ``"bi, boi -> bo"``). The result is added to, or with
    ``add_inputs=False`` copied into, ``output_tensor`` in place.

    Args:
        inputs: 2-D tensor ``(rows, in_features)``. It is cast to
            ``output_tensor.dtype`` before the product.
        lora_b_weights: Stacked weights indexed along dim 0. If the
            selection is 4-D, its dim 1 is squeezed, so each selected entry
            should reduce to ``(out_features, in_features)``.
        output_tensor: Destination tensor, modified in place. Its dtype is
            the compute dtype.
        lora_indices_tensor: Per-row indices into ``lora_b_weights``.
            NOTE(review): negative values use ordinary PyTorch indexing
            (``-1`` selects the last entry). Confirm that callers never pass
            a "no LoRA" sentinel here.
        add_inputs: Accumulate into ``output_tensor`` when True; overwrite
            when False.

    A single-row projection is broadcast to every row of ``output_tensor``.
    Otherwise, at most ``output_tensor.shape[0]`` rows are used. Only the
    leading ``min(out_features, output_tensor.shape[1])`` columns are
    touched.
    """
    target_dtype = output_tensor.dtype
    per_row_b = lora_b_weights[lora_indices_tensor].to(dtype=target_dtype)
    if per_row_b.dim() == 4:
        per_row_b = per_row_b.squeeze(dim=1)
    projected = torch.einsum("bi, boi -> bo",
                             inputs.to(dtype=target_dtype), per_row_b)

    # A lone projected row broadcasts across all output rows; otherwise
    # take no more rows than the destination has.
    row_limit = 1 if projected.shape[0] == 1 else output_tensor.shape[0]
    # LoRA adapter and model may add different amounts of padding to output
    width = min(projected.shape[1], output_tensor.shape[1])

    source = projected[:row_limit, :width]
    if add_inputs:
        output_tensor[:, :width] += source
    else:
        output_tensor[:, :width] = source

bgmv_expand_slice

bgmv_expand_slice(
    inputs: Tensor,
    lora_b_weights: Tensor,
    output_tensor: Tensor,
    lora_indices_tensor: Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
)
Source code in vllm/lora/ops/torch_ops/lora_ops.py
def bgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = True):
    """Like ``bgmv_expand``, but target a column window of the output.

    Each row of ``inputs`` is contracted with
    ``lora_b_weights[lora_indices_tensor[row]]`` (einsum
    ``"bi, boi -> bo"``). The result is added to, or with
    ``add_inputs=False`` copied into, columns
    ``[slice_offset, slice_offset + slice_size)`` of ``output_tensor`` in
    place.

    Args:
        inputs: 2-D tensor ``(rows, in_features)``, cast to
            ``output_tensor.dtype``.
        lora_b_weights: Stacked weights indexed along dim 0. A 4-D selection
            has dim 1 squeezed.
        output_tensor: Destination tensor, modified in place. Its dtype is
            the compute dtype.
        lora_indices_tensor: Per-row indices into ``lora_b_weights``.
            NOTE(review): negative values follow ordinary PyTorch indexing.
        slice_offset: First destination column.
        slice_size: Width of the destination window. The projection must
            fit it exactly (up to broadcasting), or PyTorch raises a shape
            error.
        add_inputs: Accumulate when True; overwrite when False.
    """
    target_dtype = output_tensor.dtype
    x = inputs.to(dtype=target_dtype)
    per_row_b = lora_b_weights[lora_indices_tensor].to(dtype=target_dtype)
    if per_row_b.dim() == 4:
        per_row_b = per_row_b.squeeze(dim=1)
    projected = torch.einsum("bi, boi -> bo", x, per_row_b)

    columns = slice(slice_offset, slice_offset + slice_size)
    if add_inputs:
        output_tensor[:, columns] += projected
    else:
        output_tensor[:, columns] = projected

bgmv_shrink

bgmv_shrink(
    inputs: Tensor,
    lora_b_weights: Tensor,
    output_tensor: Tensor,
    lora_indices_tensor: Tensor,
    scaling: float = 1.0,
)
Source code in vllm/lora/ops/torch_ops/lora_ops.py
def bgmv_shrink(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                scaling: float = 1.0):
    """Project each input row through its selected weights, scaled.

    Row ``b`` of ``inputs`` is contracted with
    ``lora_b_weights[lora_indices_tensor[b]]`` (einsum ``"bi, boi -> bo"``).
    The result is multiplied by ``scaling`` and written (overwriting, never
    accumulating) into the leading columns of ``output_tensor`` in place.

    Args:
        inputs: 2-D tensor ``(rows, in_features)``, cast to
            ``output_tensor.dtype``.
        lora_b_weights: Stacked weights indexed along dim 0; a 4-D selection
            has dim 1 squeezed. (The parameter name is kept for API
            compatibility even though this op performs the shrink step.)
        output_tensor: Destination tensor, modified in place. Its dtype is
            the compute dtype.
        lora_indices_tensor: Per-row indices into ``lora_b_weights``.
            NOTE(review): negative values follow ordinary PyTorch indexing.
        scaling: Scalar multiplier applied to the projection.
    """
    target_dtype = output_tensor.dtype
    per_row_w = lora_b_weights[lora_indices_tensor].to(dtype=target_dtype)
    if per_row_w.dim() == 4:
        per_row_w = per_row_w.squeeze(dim=1)
    projected = torch.einsum("bi, boi -> bo",
                             inputs.to(dtype=target_dtype), per_row_w)

    width = projected.shape[1]
    output_tensor[:, :width] = scaling * projected

sgmv_expand

sgmv_expand(
    inputs: Tensor,
    lora_b_weights: Tensor,
    output_tensor: Tensor,
    b_seq_start_loc: Tensor,
    seq_len_tensor: Tensor,
    lora_indices_tensor: Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    add_inputs: bool = False,
)
Source code in vllm/lora/ops/torch_ops/lora_ops.py
def sgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                b_seq_start_loc: torch.Tensor,
                seq_len_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                batches: int,
                max_seq_length: int,
                token_nums: int,
                add_inputs: bool = False):
    """Sequence-level front end for ``bgmv_expand``.

    Each entry of ``lora_indices_tensor`` is repeated
    ``seq_len_tensor[k]`` times, which gives one index per row of
    ``inputs``. The work is then delegated to ``bgmv_expand``, which
    updates ``output_tensor`` in place.

    ``b_seq_start_loc``, ``batches``, ``max_seq_length`` and ``token_nums``
    are not used by this implementation.
    """
    per_row_indices = torch.repeat_interleave(lora_indices_tensor,
                                              repeats=seq_len_tensor)
    bgmv_expand(inputs,
                lora_b_weights,
                output_tensor,
                per_row_indices,
                add_inputs=add_inputs)

sgmv_expand_slice

sgmv_expand_slice(
    inputs: Tensor,
    lora_b_weights: Tensor,
    output_tensor: Tensor,
    b_seq_start_loc: Tensor,
    seq_len_tensor: Tensor,
    lora_indices_tensor: Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = False,
)
Source code in vllm/lora/ops/torch_ops/lora_ops.py
def sgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      b_seq_start_loc: torch.Tensor,
                      seq_len_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      batches: int,
                      max_seq_length: int,
                      token_nums: int,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = False):
    """Sequence-level front end for ``bgmv_expand_slice``.

    Each entry of ``lora_indices_tensor`` is repeated
    ``seq_len_tensor[k]`` times, which gives one index per row of
    ``inputs``. The work is then delegated to ``bgmv_expand_slice``, which
    writes columns ``[slice_offset, slice_offset + slice_size)`` of
    ``output_tensor`` in place.

    ``b_seq_start_loc``, ``batches``, ``max_seq_length`` and ``token_nums``
    are not used by this implementation.
    """
    per_row_indices = torch.repeat_interleave(lora_indices_tensor,
                                              repeats=seq_len_tensor)
    bgmv_expand_slice(inputs,
                      lora_b_weights,
                      output_tensor,
                      per_row_indices,
                      slice_offset=slice_offset,
                      slice_size=slice_size,
                      add_inputs=add_inputs)

sgmv_shrink

sgmv_shrink(
    inputs: Tensor,
    lora_a_weights: Tensor,
    output_tensor: Tensor,
    b_seq_start_loc: Tensor,
    seq_len_tensor: Tensor,
    lora_indices_tensor: Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
)
Source code in vllm/lora/ops/torch_ops/lora_ops.py
def sgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
):
    """Sequence-level front end for ``bgmv_shrink``.

    Each entry of ``lora_indices_tensor`` is repeated
    ``seq_len_tensor[k]`` times, which gives one index per row of
    ``inputs``. The work is then delegated to ``bgmv_shrink``, passing
    ``lora_a_weights`` as its weight stack and applying ``scaling``.
    ``output_tensor`` is written in place.

    ``b_seq_start_loc``, ``batches``, ``max_seq_length`` and ``token_nums``
    are not used by this implementation.
    """
    per_row_indices = torch.repeat_interleave(lora_indices_tensor,
                                              repeats=seq_len_tensor)
    bgmv_shrink(inputs,
                lora_a_weights,
                output_tensor,
                per_row_indices,
                scaling=scaling)