Custom ops that make the prefetch offloader compatible with torch.compile and CUDA graphs.
These ops use `mutates_args` to create artificial data dependencies that prevent the compiler from reordering prefetch/sync operations relative to the surrounding computation.
_start_prefetch_fake
```python
_start_prefetch_fake(output_tensor: Tensor, layer_idx: int) -> None
```
Fake implementation for torch.compile tracing.
Source code in vllm/model_executor/offloader/prefetch_ops.py

```python
def _start_prefetch_fake(
    output_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Fake implementation for torch.compile tracing."""
    return
```
_start_prefetch_impl
```python
_start_prefetch_impl(output_tensor: Tensor, layer_idx: int) -> None
```
Start async prefetch of layer_idx weights.
Initiates H2D copy on the copy stream for the specified layer.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `output_tensor` | `Tensor` | Output from forward - declared as mutated to prevent torch.compile from reordering this op before the computation that produces `output_tensor`. | *required* |
| `layer_idx` | `int` | Index of the layer to prefetch. | *required* |
Source code in vllm/model_executor/offloader/prefetch_ops.py

```python
def _start_prefetch_impl(
    output_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Start async prefetch of layer_idx weights.

    Initiates H2D copy on the copy stream for the specified layer.

    Args:
        output_tensor: Output from forward - declared as mutated to
            prevent torch.compile from reordering this op before the
            computation that produces output_tensor.
        layer_idx: Index of the layer to prefetch.
    """
    get_offloader()._start_prefetch(layer_idx)
```
_wait_prefetch_fake
```python
_wait_prefetch_fake(input_tensor: Tensor, layer_idx: int) -> None
```
Fake implementation for torch.compile tracing.
Source code in vllm/model_executor/offloader/prefetch_ops.py

```python
def _wait_prefetch_fake(
    input_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Fake implementation for torch.compile tracing."""
    return
```
_wait_prefetch_impl
```python
_wait_prefetch_impl(input_tensor: Tensor, layer_idx: int) -> None
```
Wait for prefetch of layer_idx to complete.
Synchronizes the compute stream with the copy stream to ensure the prefetched weights are ready for use.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_tensor` | `Tensor` | Input to the layer (e.g., hidden_states) - declared as mutated to create a data dependency for torch.compile. | *required* |
| `layer_idx` | `int` | Index of the layer to wait for. | *required* |
Source code in vllm/model_executor/offloader/prefetch_ops.py

```python
def _wait_prefetch_impl(
    input_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Wait for prefetch of layer_idx to complete.

    Synchronizes the compute stream with the copy stream to ensure
    the prefetched weights are ready for use.

    Args:
        input_tensor: Input to the layer (e.g., hidden_states) - declared
            as mutated to create data dependency for torch.compile.
        layer_idx: Index of the layer to wait for.
    """
    get_offloader()._wait_for_layer(layer_idx)
```
register_prefetch_offloader_ops
```python
register_prefetch_offloader_ops() -> None
```
Register custom ops for prefetch offloader.
Must be called before the ops are used. This is typically done at module import time.
Source code in vllm/model_executor/offloader/prefetch_ops.py

```python
def register_prefetch_offloader_ops() -> None:
    """Register custom ops for prefetch offloader.

    Must be called before the ops are used. This is typically done
    at module import time.
    """
    direct_register_custom_op(
        op_name="wait_prefetch",
        op_func=_wait_prefetch_impl,
        mutates_args=["input_tensor"],
        fake_impl=_wait_prefetch_fake,
    )
    direct_register_custom_op(
        op_name="start_prefetch",
        op_func=_start_prefetch_impl,
        mutates_args=["output_tensor"],
        fake_impl=_start_prefetch_fake,
    )
```