
vllm.model_executor.offloader.prefetch_ops

Custom ops that make the prefetch offloader compatible with torch.compile and CUDA graphs.

Each op declares one of its tensor arguments in mutates_args. This creates a data dependency that prevents the compiler from reordering the prefetch and wait (sync) operations relative to the surrounding computation.
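The effect of mutates_args can be pictured as a toy dependency analysis. The sketch below is not vLLM or PyTorch code: `Op`, `dependency_edges`, `must_precede` and the tensor names `x` and `out` are hypothetical. Each op lists the tensors it reads and writes, and a mutated argument is modeled as both read and written. Any read-after-write, write-after-read or write-after-write conflict adds an ordering edge that a scheduler must respect. Without the mutation declaration, start_prefetch touches no tensors and is free to move. With it, the op is pinned after the computation that produces `out`.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    """Hypothetical op record: the tensor names it reads and writes."""
    name: str
    reads: set = field(default_factory=set)
    writes: set = field(default_factory=set)


def dependency_edges(ops):
    """Return (earlier, later) pairs whose relative order must be kept."""
    edges = set()
    for i, a in enumerate(ops):
        for b in ops[i + 1:]:
            raw = a.writes & b.reads   # read-after-write
            war = a.reads & b.writes   # write-after-read
            waw = a.writes & b.writes  # write-after-write
            if raw or war or waw:
                edges.add((a.name, b.name))
    return edges


def must_precede(ops, first, second):
    """True if `first` must stay before `second`, directly or transitively."""
    edges = dependency_edges(ops)
    frontier, seen = [first], set()
    while frontier:
        node = frontier.pop()
        for a, b in edges:
            if a == node and b not in seen:
                if b == second:
                    return True
                seen.add(b)
                frontier.append(b)
    return False


# No mutation declared: the prefetch ops touch no tensors at all.
free = [
    Op("wait_prefetch"),
    Op("layer0_forward", reads={"x"}, writes={"out"}),
    Op("start_prefetch"),
]
# Mutation declared on the layer input (wait) and the layer output (start).
pinned = [
    Op("wait_prefetch", reads={"x"}, writes={"x"}),
    Op("layer0_forward", reads={"x"}, writes={"out"}),
    Op("start_prefetch", reads={"out"}, writes={"out"}),
]
print(must_precede(free, "layer0_forward", "start_prefetch"))    # False
print(must_precede(pinned, "layer0_forward", "start_prefetch"))  # True
print(must_precede(pinned, "wait_prefetch", "layer0_forward"))   # True
```

The real compiler's analysis is more involved than this, but the principle is the same: an op with no data dependencies has nothing anchoring its position in the graph.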

_start_prefetch_fake

_start_prefetch_fake(
    output_tensor: Tensor, layer_idx: int
) -> None

Fake implementation for torch.compile tracing.

Source code in vllm/model_executor/offloader/prefetch_ops.py
def _start_prefetch_fake(
    output_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Fake implementation for torch.compile tracing."""
    return
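While tracing, torch.compile broadly works with fake tensors and calls an op's fake implementation instead of its real one. That is why the fake version here can simply return: it only has to describe the op's outputs (none) and must not trigger side effects such as starting a copy. The toy dispatcher below illustrates that split. `ToyOp` and its `tracing` flag are hypothetical and are not the PyTorch dispatcher.

```python
class ToyOp:
    """Hypothetical stand-in for a custom op with a real and a fake kernel."""

    def __init__(self, impl, fake_impl):
        self.impl = impl
        self.fake_impl = fake_impl

    def __call__(self, *args, tracing=False):
        # While tracing, only the fake kernel runs, so no real work happens.
        kernel = self.fake_impl if tracing else self.impl
        return kernel(*args)


started = []  # layers for which a real prefetch was started


def start_prefetch_impl(output_tensor, layer_idx):
    started.append(layer_idx)


def start_prefetch_fake(output_tensor, layer_idx):
    return None


start_prefetch = ToyOp(start_prefetch_impl, start_prefetch_fake)

start_prefetch("out", 3, tracing=True)  # trace time: no side effect
start_prefetch("out", 3)                # run time: prefetch is recorded
print(started)  # [3]
```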

_start_prefetch_impl

_start_prefetch_impl(
    output_tensor: Tensor, layer_idx: int
) -> None

Start async prefetch of layer_idx weights.

Initiates a host-to-device (H2D) copy of the specified layer's weights on the copy stream.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| output_tensor | Tensor | Output from the forward pass; declared as mutated to prevent torch.compile from reordering this op before the computation that produces output_tensor. | required |
| layer_idx | int | Index of the layer to prefetch. | required |
Source code in vllm/model_executor/offloader/prefetch_ops.py
def _start_prefetch_impl(
    output_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Start async prefetch of layer_idx weights.

    Initiates H2D copy on the copy stream for the specified layer.

    Args:
        output_tensor: Output from forward - declared as mutated to
            prevent torch.compile from reordering this op before the
            computation that produces output_tensor.
        layer_idx: Index of the layer to prefetch.
    """
    get_offloader()._start_prefetch(layer_idx)
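Conceptually, _start_prefetch enqueues a copy and returns immediately so that the transfer can overlap with compute, and _wait_for_layer later blocks until that copy is done. The sketch below models this with a single-worker `ThreadPoolExecutor` standing in for the CUDA copy stream and a `Future` standing in for a CUDA event. `ToyOffloader`, `cpu_weights` and `gpu_weights` are hypothetical; only the two method names mirror the calls above, and the real offloader's internals may differ. One important difference: a CUDA stream wait makes the compute stream wait without blocking the host, whereas here the caller's thread blocks.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class ToyOffloader:
    """Hypothetical model of a prefetching offloader.

    A single-worker executor plays the copy stream (copies run in
    submission order, off the caller's thread); each Future plays the
    event that marks one layer's copy as finished.
    """

    def __init__(self, cpu_weights):
        self.cpu_weights = cpu_weights  # layer_idx -> host-side weights
        self.gpu_weights = {}           # layer_idx -> "device" copy
        self.copy_stream = ThreadPoolExecutor(max_workers=1)
        self.pending = {}               # layer_idx -> Future of its copy

    def _copy(self, layer_idx):
        time.sleep(0.01)  # pretend the H2D transfer takes a while
        self.gpu_weights[layer_idx] = list(self.cpu_weights[layer_idx])

    def _start_prefetch(self, layer_idx):
        # Enqueue the copy and return at once, like a non-blocking H2D copy.
        self.pending[layer_idx] = self.copy_stream.submit(self._copy, layer_idx)

    def _wait_for_layer(self, layer_idx):
        # Block until this layer's copy has finished, then hand out the weights.
        future = self.pending.pop(layer_idx, None)
        if future is not None:
            future.result()
        return self.gpu_weights[layer_idx]


offloader = ToyOffloader({0: [1.0, 2.0], 1: [3.0, 4.0]})
offloader._start_prefetch(1)  # returns before the copy necessarily finishes
weights = offloader._wait_for_layer(1)
print(weights)  # [3.0, 4.0]
offloader.copy_stream.shutdown()
```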

_wait_prefetch_fake

_wait_prefetch_fake(
    input_tensor: Tensor, layer_idx: int
) -> None

Fake implementation for torch.compile tracing.

Source code in vllm/model_executor/offloader/prefetch_ops.py
def _wait_prefetch_fake(
    input_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Fake implementation for torch.compile tracing."""
    return

_wait_prefetch_impl

_wait_prefetch_impl(
    input_tensor: Tensor, layer_idx: int
) -> None

Wait for prefetch of layer_idx to complete.

Synchronizes the compute stream with the copy stream to ensure the prefetched weights are ready for use.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input_tensor | Tensor | Input to the layer (e.g., hidden_states); declared as mutated to create a data dependency for torch.compile. | required |
| layer_idx | int | Index of the layer to wait for. | required |
Source code in vllm/model_executor/offloader/prefetch_ops.py
def _wait_prefetch_impl(
    input_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Wait for prefetch of layer_idx to complete.

    Synchronizes the compute stream with the copy stream to ensure
    the prefetched weights are ready for use.

    Args:
        input_tensor: Input to the layer (e.g., hidden_states) - declared
            as mutated to create data dependency for torch.compile.
        layer_idx: Index of the layer to wait for.
    """
    get_offloader()._wait_for_layer(layer_idx)
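The parameter descriptions imply a placement: the wait op takes a layer's input, so it sits just before that layer runs, and the start op takes a forward output, so it sits just after. The hypothetical decoder loop below records that ordering. The `distance=2` prefetch-ahead and the warm-up prefetches before the loop are assumptions made for illustration, not vLLM's actual schedule. Starting the copy for layer `i + 2` after layer `i` gives that transfer the whole of layer `i + 1`'s compute to overlap with.

```python
log = []  # order in which the ops run


def wait_prefetch(input_tensor, layer_idx):
    log.append(f"wait {layer_idx}")


def start_prefetch(output_tensor, layer_idx):
    log.append(f"start {layer_idx}")


def layer_forward(hidden_states, layer_idx):
    log.append(f"compute {layer_idx}")
    return hidden_states + 1  # stand-in for the real layer math


def forward(hidden_states, num_layers, distance=2):
    """Hypothetical loop showing where the two prefetch ops are placed."""
    # Warm-up (assumption): request the first `distance` layers up front.
    for j in range(min(distance, num_layers)):
        start_prefetch(hidden_states, j)
    for i in range(num_layers):
        wait_prefetch(hidden_states, i)  # layer i's weights must be ready
        hidden_states = layer_forward(hidden_states, i)
        nxt = i + distance
        if nxt < num_layers:
            start_prefetch(hidden_states, nxt)  # ordered after layer i's output
    return hidden_states


out = forward(0, num_layers=3)
print(out)  # 3
print(log)
# ['start 0', 'start 1', 'wait 0', 'compute 0', 'start 2',
#  'wait 1', 'compute 1', 'wait 2', 'compute 2']
```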

register_prefetch_offloader_ops

register_prefetch_offloader_ops() -> None

Register custom ops for prefetch offloader.

Must be called before the ops are used. This is typically done at module import time.

Source code in vllm/model_executor/offloader/prefetch_ops.py
def register_prefetch_offloader_ops() -> None:
    """Register custom ops for prefetch offloader.

    Must be called before the ops are used. This is typically done
    at module import time.
    """
    direct_register_custom_op(
        op_name="wait_prefetch",
        op_func=_wait_prefetch_impl,
        mutates_args=["input_tensor"],
        fake_impl=_wait_prefetch_fake,
    )

    direct_register_custom_op(
        op_name="start_prefetch",
        op_func=_start_prefetch_impl,
        mutates_args=["output_tensor"],
        fake_impl=_start_prefetch_fake,
    )
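Two rules are implicit here: an op has to be registered before anything calls it, and each mutates_args entry must name a real parameter of op_func (`input_tensor` for wait_prefetch, `output_tensor` for start_prefetch). The small registry below models both. `REGISTRY`, `register_op` and `call_op` are hypothetical. vLLM's direct_register_custom_op registers the op with PyTorch and its exact checks may differ. In plain PyTorch, `torch.library.custom_op(name, mutates_args=...)` together with the returned op's `register_fake` method covers similar ground.

```python
import inspect

REGISTRY = {}  # op_name -> (op_func, mutated arg names, fake_impl)


def register_op(op_name, op_func, mutates_args, fake_impl):
    """Hypothetical registry: check mutates_args, then store both kernels."""
    params = inspect.signature(op_func).parameters
    unknown = [name for name in mutates_args if name not in params]
    if unknown:
        raise ValueError(f"{op_name}: mutates_args names unknown params {unknown}")
    REGISTRY[op_name] = (op_func, tuple(mutates_args), fake_impl)


def call_op(op_name, *args):
    if op_name not in REGISTRY:
        raise RuntimeError(f"op {op_name!r} used before registration")
    op_func, _, _ = REGISTRY[op_name]
    return op_func(*args)


def wait_prefetch_impl(input_tensor, layer_idx):
    return f"waited for layer {layer_idx}"


def wait_prefetch_fake(input_tensor, layer_idx):
    return None


try:
    call_op("wait_prefetch", "x", 0)
except RuntimeError as err:
    print(err)  # op 'wait_prefetch' used before registration

# Registering once, e.g. at import time, makes every later call succeed.
register_op("wait_prefetch", wait_prefetch_impl, ["input_tensor"], wait_prefetch_fake)
print(call_op("wait_prefetch", "x", 0))  # waited for layer 0
```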