vllm.lora.ops.xla_ops.lora_ops

bgmv_expand

bgmv_expand(
    inputs: Tensor,
    lora_b_weights: Tensor,
    output_tensor: Tensor,
    lora_indices_tensor: Tensor,
    add_inputs: bool = True,
)

Parameters:

inputs (Tensor): Input tensor of shape [num_tokens, hidden_size]. Required.
lora_b_weights (Tensor): LoRA weights of shape [num_loras, lora_rank, hidden_size]. Required.
output_tensor (Tensor): Output tensor of shape [num_tokens, hidden_size * num_slices]. Required.
lora_indices_tensor (Tensor): Tensor of shape [num_tokens] indicating which LoRA matrix to use for each token. Required.
add_inputs (bool): Whether or not to add the input tensor to the output tensor. Default: True.
Source code in vllm/lora/ops/xla_ops/lora_ops.py
def bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
):
    """
    Args:
        inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].

        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].

        output_tensor (torch.Tensor): output tensor of shape
            [num_tokens, hidden_size * num_slices].

        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.
        add_inputs (bool): Whether or not to add the input tensor to the output
            tensor.
    """

    outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)

    limit = output_tensor.shape[0]
    if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
        limit = 1

    if output_tensor.shape[1] > outputs.shape[1]:
        outputs = F.pad(outputs,
                        (0, output_tensor.shape[1] - outputs.shape[1], 0, 0))

    if add_inputs:
        return output_tensor + outputs[:limit, :output_tensor.shape[1]]
    else:
        return outputs[:limit, :output_tensor.shape[1]]
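
The tail of this function only post-processes the bgmv result: it pads it on the right up to the output width and, when add_inputs is true, adds it to output_tensor. A minimal sketch of that post-processing with plain tensors (the bgmv result is stood in by a random tensor; all shapes are hypothetical):

import torch
import torch.nn.functional as F

outputs = torch.randn(4, 8)          # stand-in for the torch.ops.xla.bgmv result
output_tensor = torch.zeros(4, 16)   # [num_tokens, hidden_size * num_slices]

# Pad on the right so the bgmv result matches the output width, then add it to
# the existing output (the add_inputs=True path).
if output_tensor.shape[1] > outputs.shape[1]:
    outputs = F.pad(outputs, (0, output_tensor.shape[1] - outputs.shape[1], 0, 0))
print((output_tensor + outputs[:, :output_tensor.shape[1]]).shape)  # torch.Size([4, 16])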

bgmv_expand_slice

bgmv_expand_slice(
    inputs: Tensor,
    lora_b_weights: Tensor,
    output_tensor: Tensor,
    lora_indices_tensor: Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
)

Parameters:

inputs (Tensor): Input tensor of shape [num_tokens, hidden_size]. Required.
lora_b_weights (Tensor): LoRA weights of shape [num_loras, lora_rank, hidden_size]. Required.
output_tensor (Tensor): Output tensor of shape [num_tokens, hidden_size * num_slices]. Required.
lora_indices_tensor (Tensor): Tensor of shape [num_tokens] indicating which LoRA matrix to use for each token. Required.
slice_offset (int): Column offset in the output tensor at which this slice's result is written. Required.
slice_size (int): Width of the slice written into the output tensor. Required.
add_inputs (bool): Whether or not to add the input tensor to the output tensor. Default: True.
Source code in vllm/lora/ops/xla_ops/lora_ops.py
def bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
):
    """
    Args:
        inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].

        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].

        output_tensor (torch.Tensor): output tensor of shape
            [num_tokens, hidden_size * num_slices].

        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.
        add_inputs (bool): Whether or not to add the input tensor to the output
            tensor.
    """
    outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)

    outputs = F.pad(
        outputs,
        (
            slice_offset,
            output_tensor.shape[1] - (slice_offset + slice_size),
            0,
            0,
        ),
    )

    if add_inputs:
        return output_tensor + outputs
    else:
        return outputs
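
Compared to bgmv_expand, the only difference is where the result lands: padding on both sides writes it into columns [slice_offset, slice_offset + slice_size) of the full output. A minimal sketch of that placement, with hypothetical shapes:

import torch
import torch.nn.functional as F

outputs = torch.ones(4, 8)           # stand-in for the torch.ops.xla.bgmv result
output_tensor = torch.zeros(4, 24)   # [num_tokens, hidden_size * num_slices]
slice_offset, slice_size = 8, 8

# Left/right zero padding places the 8-wide result into columns [8, 16).
padded = F.pad(
    outputs,
    (slice_offset, output_tensor.shape[1] - (slice_offset + slice_size), 0, 0),
)
print((output_tensor + padded)[0])   # zeros, then eight ones, then zeros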

bgmv_jax

bgmv_jax(inputs, loras, idxs)
Source code in vllm/lora/ops/xla_ops/lora_ops.py
@jax.jit
def bgmv_jax(inputs, loras, idxs):
    return jnp.einsum(
        "td,tX,Xld->tl",
        inputs,
        jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
        loras,
    )
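
The one-hot einsum selects one LoRA matrix per token without an explicit gather, expressing the whole lookup as a single dense contraction. A small equivalence check against a direct per-token gather (shapes below are hypothetical, chosen only for illustration):

import jax
import jax.numpy as jnp

# 4 tokens, hidden_size 16, lora_rank 8, 2 LoRAs.
inputs = jax.random.normal(jax.random.PRNGKey(0), (4, 16))
loras = jax.random.normal(jax.random.PRNGKey(1), (2, 8, 16))
idxs = jnp.array([0, 1, 1, 0])

# One-hot einsum, as in bgmv_jax above.
one_hot = jnp.einsum(
    "td,tX,Xld->tl",
    inputs,
    jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
    loras,
)
# Explicit per-token gather followed by a batched contraction.
gathered = jnp.einsum("td,tld->tl", inputs, loras[idxs])
print(jnp.allclose(one_hot, gathered, atol=1e-5))  # True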

bgmv_non_xla

bgmv_non_xla(
    inputs: Tensor, loras: Tensor, idxs: IntTensor
)
Source code in vllm/lora/ops/xla_ops/lora_ops.py
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
                 idxs: torch.IntTensor):
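    # Fallback for non-XLA backends (e.g. shape inference during tracing): only
    # the output shape matters here, so an uninitialized [T, L] tensor is enough.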
    T, _ = inputs.shape
    if len(loras.shape) == 4:
        loras = loras.squeeze(axis=1)
    _, L, _ = loras.shape

    return torch.empty((T, L), device=inputs.device)

bgmv_shrink

bgmv_shrink(
    inputs: Tensor,
    lora_b_weights: Tensor,
    lora_indices_tensor: Tensor,
    scaling: float = 1.0,
)

Parameters:

inputs (Tensor): Input tensor of shape [num_tokens, hidden_size]. Required.
lora_b_weights (Tensor): LoRA weights of shape [num_loras, lora_rank, hidden_size]. Required.
lora_indices_tensor (Tensor): Tensor of shape [num_tokens] indicating which LoRA matrix to use for each token. Required.
scaling (float): Scalar multiplier applied to the output. Default: 1.0.
Source code in vllm/lora/ops/xla_ops/lora_ops.py
def bgmv_shrink(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
):
    """
    Args:
        inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].
        output_tensor (torch.Tensor): (Unused) output tensor (placeholder).
        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.
        scaling (float, optional): Scalar multiplier applied to the output.
    """

    return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights,
                                        lora_indices_tensor)
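
For illustration, the same computation (per-token LoRA selection, contraction over hidden_size, then scaling) can be sketched in plain PyTorch; this is a reference sketch, not the registered XLA kernel, and the names and shapes below are hypothetical:

import torch

def bgmv_shrink_reference(inputs, lora_weights, idxs, scaling=1.0):
    # Pick one LoRA matrix per token, contract over hidden_size, then scale,
    # mirroring scaling * bgmv(inputs, lora_weights, idxs).
    return scaling * torch.einsum("td,tld->tl", inputs, lora_weights[idxs])

out = bgmv_shrink_reference(torch.randn(4, 16), torch.randn(2, 8, 16),
                            torch.randint(0, 2, (4,)), scaling=0.5)
print(out.shape)  # torch.Size([4, 8])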

bgmv_xla

bgmv_xla(inputs: Tensor, loras: Tensor, idxs: IntTensor)
Source code in vllm/lora/ops/xla_ops/lora_ops.py
@impl(XLA_LIB, "bgmv", "XLA")
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
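    # Drop an optional singleton dimension so the weights are 3-D:
    # [num_loras, 1, lora_rank, hidden_size] -> [num_loras, lora_rank, hidden_size].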
    if len(loras.shape) == 4:
        loras = loras.squeeze(axis=1)

    jax_import_guard()
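    # Hand the torch/XLA tensors to the jitted JAX kernel (bgmv_jax) defined above.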
    return xb.call_jax(bgmv_jax, (inputs, loras, idxs))