Skip to content

vllm.model_executor.offloader.prefetch

Prefetch-based CPU offloading with async prefetching.

Uses static buffers and event-based stream forking for torch.compile + CUDA graph compatibility. Events allow the copy stream to join CUDA graph captures, ensuring H2D copies are properly captured.

ParamInfo dataclass

Metadata about an offloaded parameter.

Source code in vllm/model_executor/offloader/prefetch.py
@dataclass
class ParamInfo:
    """Metadata about an offloaded parameter."""

    name: str
    shape: tuple[int, ...]
    stride: tuple[int, ...]
    dtype: torch.dtype

    @property
    def key(self) -> tuple[str, tuple[int, ...], tuple[int, ...], torch.dtype]:
        """Unique key for buffer pool grouping.

        Includes parameter name to prevent different parameters with the same
        shape from sharing buffers within the same layer. Parameters with the
        same name across different layers will share buffers (via slots).

        Includes stride because parameters with same shape but different
        strides need separate buffers to preserve memory layout.
        """
        return (self.name, self.shape, self.stride, self.dtype)

    @property
    def num_bytes(self) -> int:
        """Size in bytes."""
        numel = 1
        for dim in self.shape:
            numel *= dim
        return numel * torch.finfo(self.dtype).bits // 8

key property

key: tuple[str, tuple[int, ...], tuple[int, ...], dtype]

Unique key for buffer pool grouping.

Includes parameter name to prevent different parameters with the same shape from sharing buffers within the same layer. Parameters with the same name across different layers will share buffers (via slots).

Includes stride because parameters with same shape but different strides need separate buffers to preserve memory layout.

num_bytes property

num_bytes: int

Size in bytes.

PrefetchOffloader

Bases: BaseOffloader

Prefetching-based offloader with group-based layer selection.

Groups layers and uses async H2D prefetch to hide transfer latency. Uses static buffers and stream synchronization for torch.compile and CUDA graph compatibility.

Parameters:

Name Type Description Default
group_size int

Group every N layers together.

required
num_in_group int

Offload this many layers per group (last N of each group).

required
prefetch_step int

Number of layers to prefetch ahead.

required
mode str

Offload mode ("cpu" is currently supported).

'cpu'
Source code in vllm/model_executor/offloader/prefetch.py
class PrefetchOffloader(BaseOffloader):
    """Prefetching-based offloader with group-based layer selection.

    Groups layers and uses async H2D prefetch to hide transfer latency.
    Uses static buffers and stream synchronization for torch.compile and
    CUDA graph compatibility.

    Args:
        group_size: Group every N layers together.
        num_in_group: Offload this many layers per group (last N of each group).
        prefetch_step: Number of layers to prefetch ahead.
        offload_params: Optional parameter-name fragments; only parameters
            matching one of these are offloaded. Empty/None means all params.
        mode: Offload mode ("cpu" is currently supported).
    """

    def __init__(
        self,
        group_size: int,
        num_in_group: int,
        prefetch_step: int,
        offload_params: set[str] | None = None,
        mode: str = "cpu",
    ):
        self.group_size = group_size
        self.num_in_group = num_in_group
        self.prefetch_step = prefetch_step
        self.offload_params = offload_params or set()
        self.mode = mode

        # Copy stream for async H2D transfers
        self.copy_stream = torch.cuda.Stream()

        # Module offloaders and buffer pool (populated in wrap_modules/post_init)
        self.module_offloaders: list[_ModuleOffloader] = []
        self.buffer_pool: StaticBufferPool | None = None
        self.total_offloaded_bytes = 0

    def wrap_modules(
        self,
        modules_generator: Generator[nn.Module, None, None],
    ) -> list[nn.Module]:
        """Wrap modules with prefetch offloading logic."""
        assert len(self.module_offloaders) == 0, (
            "wrap_modules should only be called once"
        )

        all_modules = []
        offload_modules = []

        for module_index, module in enumerate(modules_generator):
            all_modules.append(module)

            # Select layers to offload based on group pattern
            # Offload last num_in_group layers of each group_size
            if module_index % self.group_size >= self.group_size - self.num_in_group:
                if self.offload_params:
                    # Dot-delimited containment check so a fragment like "mlp"
                    # matches "mlp.weight" but not "mlp_extra.weight".
                    whitelist = [
                        name
                        for name, _ in module.named_parameters()
                        if any(f".{p}." in f".{name}." for p in self.offload_params)
                    ]
                else:
                    whitelist = [name for name, _ in module.named_parameters()]

                if not whitelist:
                    continue  # skip layers with no matching params

                offload_modules.append(module)
                self.module_offloaders.append(
                    _ModuleOffloader(
                        mode=self.mode,
                        module=module,
                        copy_stream=self.copy_stream,
                        whitelist_param_names=whitelist,
                        layer_idx=len(self.module_offloaders),
                    )
                )

        for index, module in enumerate(offload_modules):
            self._hook_module_forward(index, module)

        return all_modules

    def _hook_module_forward(self, index: int, module: nn.Module) -> None:
        """Hook module's forward with torch.compile-compatible sync."""
        original_forward = module.forward

        def forward(*args, **kwargs):
            # Temporarily restore original forward to avoid recursion
            module.forward = original_forward

            # Wait for this layer's prefetch to complete
            # mutates_args on input_tensor creates data dependency for torch.compile
            # NOTE(review): assumes the activation is the first positional arg
            # or the "hidden_states" kwarg; input_tensor may be None otherwise
            # — confirm against wrapped module signatures.
            input_tensor = args[0] if args else kwargs.get("hidden_states")
            torch.ops.vllm.wait_prefetch(input_tensor, index)

            # No parameter swapping needed - parameters already point to
            # GPU static buffers (set in assign_static_buffer)
            output = original_forward(*args, **kwargs)

            # Start prefetch for next layer (circular)
            # mutates_args on output_tensor creates ordering dependency
            next_index = (index + self.prefetch_step) % len(self.module_offloaders)
            # Handle tuple output (e.g., (hidden_states, residual))
            if isinstance(output, tuple):
                torch.ops.vllm.start_prefetch(output[0], next_index)
            else:
                torch.ops.vllm.start_prefetch(output, next_index)

            # No explicit offload needed - static buffers are reused implicitly

            # Restore hooked forward
            module.forward = forward
            return output

        module.forward = forward

    def _wait_for_layer(self, layer_idx: int) -> None:
        """Called by custom op - wait for copy to complete.

        Synchronization strategy:
        - During CUDA graph capture: use event-based wait (graph-compatible)
        - Outside capture (warmup/eager): use wait_stream (more robust)

        During capture, we skip wait for pre-capture prefetches because:
        1. sync_before_graph_capture() ensures pre-capture work is complete
        2. We can't wait on pre-capture events during capture (isolation error)
        """
        offloader = self.module_offloaders[layer_idx]

        if torch.cuda.is_current_stream_capturing():
            # During capture, skip wait for pre-capture prefetches.
            # sync_before_graph_capture() ensures pre-capture work is complete.
            if not offloader._prefetch_in_capture:
                return
            # Event-based wait for in-capture prefetches (graph-compatible)
            torch.cuda.current_stream().wait_event(offloader._copy_done_event)
            # Mark that this prefetch has been waited on (joined).
            offloader._prefetch_in_capture = False
        else:
            if offloader._event_valid_for_eager:
                # Use per-layer event to only wait for THIS layer's copy,
                # allowing other layers' prefetches to run concurrently.
                torch.cuda.current_stream().wait_event(offloader._copy_done_event)
            else:
                # Event not usable (unrecorded or recorded during capture).
                # Fall back to wait_stream to drain all copy_stream work.
                torch.cuda.current_stream().wait_stream(self.copy_stream)

    def sync_prev_onload(self) -> None:
        """Sync previous onload operations.

        Ensures any H2D copies in flight on copy_stream complete before
        the compute stream continues. Call this before CUDA graph
        capture/replay or when synchronization is needed.
        """
        torch.cuda.current_stream().wait_stream(self.copy_stream)

    def _start_prefetch(self, layer_idx: int) -> None:
        """Called by custom op - start async copy to static buffer."""
        offloader = self.module_offloaders[layer_idx]
        offloader.start_onload_to_static()

    def join_after_forward(self) -> None:
        """Join copy_stream after model forward completes.

        Call this after the model forward pass but before CUDA graph capture
        ends. This ensures copy_stream is rejoined for any prefetches started
        during the forward pass.

        We join ALL layers that have _prefetch_in_capture=True, meaning their
        prefetch was started during capture but not yet waited on (joined).
        This handles both full and piecewise cudagraph modes correctly:
        - Full mode: joins layers 0..prefetch_step-1 (prefetched by last layers)
        - Piecewise mode: joins only layers prefetched by THIS subgraph's layers
        """
        if not self.module_offloaders:
            return
        # Join all layers whose prefetch was started in capture but not waited on
        for offloader in self.module_offloaders:
            if offloader._prefetch_in_capture:
                torch.cuda.current_stream().wait_event(offloader._copy_done_event)
                offloader._prefetch_in_capture = False

    def post_init(self) -> None:
        """Allocate static buffer pool and start initial prefetches.

        Note: Parameters have already been offloaded to CPU during wrap_modules()
        (in _CpuParamOffloader.__init__), so GPU memory is available for the
        static buffer pool.
        """
        # Sync CPU storage with current param.data BEFORE collecting param info.
        # This is needed because process_weights_after_loading may have:
        # 1. Transformed weights (quantization, transpose, etc.)
        # 2. Created new CPU tensors via device_loading_context
        # Our _cpu_storage would be stale otherwise.
        for offloader in self.module_offloaders:
            offloader.sync_cpu_storage()

        # Collect parameter info (now using synced CPU storage)
        param_infos: list[ParamInfo] = []
        device: torch.device | None = None

        for offloader in self.module_offloaders:
            param_infos.extend(offloader.get_param_infos())
            if device is None:
                device = offloader.device

        if device is None:
            # No modules to offload
            return

        # Allocate static buffer pool
        self.buffer_pool = StaticBufferPool(
            param_infos=param_infos,
            slot_capacity=self.prefetch_step,
            device=device,
        )

        # Assign buffer slots and point parameters to GPU buffers
        for idx, offloader in enumerate(self.module_offloaders):
            slot_idx = idx % self.prefetch_step
            offloader.assign_buffer_slot(self.buffer_pool, slot_idx)

        # Collect offloaded bytes
        for offloader in self.module_offloaders:
            offloader.post_init()
            self.total_offloaded_bytes += offloader.offloaded_bytes

        logger.info_once(
            f"[PrefetchOffloader] Initialized {len(self.module_offloaders)} modules. "
            f"Total GPU memory saved: {self.total_offloaded_bytes / 1e9:.4f} GB, "
            f"Static buffer pool: {self.buffer_pool.total_bytes / 1e9:.4f} GB "
            f"(group_size={self.group_size}, num_in_group={self.num_in_group}, "
            f"prefetch_step={self.prefetch_step}, mode={self.mode})"
        )

        # Start initial prefetches
        for i in range(min(self.prefetch_step, len(self.module_offloaders))):
            self.module_offloaders[i].start_onload_to_static()

_hook_module_forward

_hook_module_forward(index: int, module: Module)

Hook module's forward with torch.compile-compatible sync.

Source code in vllm/model_executor/offloader/prefetch.py
def _hook_module_forward(self, index: int, module: nn.Module) -> None:
    """Hook module's forward with torch.compile-compatible sync."""
    original_forward = module.forward

    def forward(*args, **kwargs):
        # Temporarily restore original forward to avoid recursion
        module.forward = original_forward

        # Wait for this layer's prefetch to complete
        # mutates_args on input_tensor creates data dependency for torch.compile
        # NOTE(review): assumes the activation is the first positional arg or
        # the "hidden_states" kwarg; input_tensor may be None otherwise —
        # confirm against wrapped module signatures.
        input_tensor = args[0] if args else kwargs.get("hidden_states")
        torch.ops.vllm.wait_prefetch(input_tensor, index)

        # No parameter swapping needed - parameters already point to
        # GPU static buffers (set in assign_static_buffer)
        output = original_forward(*args, **kwargs)

        # Start prefetch for next layer (circular)
        # mutates_args on output_tensor creates ordering dependency
        next_index = (index + self.prefetch_step) % len(self.module_offloaders)
        # Handle tuple output (e.g., (hidden_states, residual))
        if isinstance(output, tuple):
            torch.ops.vllm.start_prefetch(output[0], next_index)
        else:
            torch.ops.vllm.start_prefetch(output, next_index)

        # No explicit offload needed - static buffers are reused implicitly

        # Restore hooked forward
        module.forward = forward
        return output

    module.forward = forward

_start_prefetch

_start_prefetch(layer_idx: int)

Called by custom op - start async copy to static buffer.

Source code in vllm/model_executor/offloader/prefetch.py
def _start_prefetch(self, layer_idx: int):
    """Called by custom op - start async copy to static buffer."""
    offloader = self.module_offloaders[layer_idx]
    offloader.start_onload_to_static()

_wait_for_layer

_wait_for_layer(layer_idx: int)

Called by custom op - wait for copy to complete.

Synchronization strategy: - During CUDA graph capture: use event-based wait (graph-compatible) - Outside capture (warmup/eager): use wait_stream (more robust)

During capture, we skip wait for pre-capture prefetches because: 1. sync_before_graph_capture() ensures pre-capture work is complete 2. We can't wait on pre-capture events during capture (isolation error)

Source code in vllm/model_executor/offloader/prefetch.py
def _wait_for_layer(self, layer_idx: int) -> None:
    """Called by custom op - wait for copy to complete.

    Synchronization strategy:
    - During CUDA graph capture: use event-based wait (graph-compatible)
    - Outside capture (warmup/eager): use wait_stream (more robust)

    During capture, we skip wait for pre-capture prefetches because:
    1. sync_before_graph_capture() ensures pre-capture work is complete
    2. We can't wait on pre-capture events during capture (isolation error)
    """
    offloader = self.module_offloaders[layer_idx]

    if torch.cuda.is_current_stream_capturing():
        # During capture, skip wait for pre-capture prefetches.
        # sync_before_graph_capture() ensures pre-capture work is complete.
        if not offloader._prefetch_in_capture:
            return
        # Event-based wait for in-capture prefetches (graph-compatible)
        torch.cuda.current_stream().wait_event(offloader._copy_done_event)
        # Mark that this prefetch has been waited on (joined).
        offloader._prefetch_in_capture = False
    else:
        if offloader._event_valid_for_eager:
            # Use per-layer event to only wait for THIS layer's copy,
            # allowing other layers' prefetches to run concurrently.
            torch.cuda.current_stream().wait_event(offloader._copy_done_event)
        else:
            # Event not usable (unrecorded or recorded during capture).
            # Fall back to wait_stream to drain all copy_stream work.
            torch.cuda.current_stream().wait_stream(self.copy_stream)

join_after_forward

join_after_forward()

Join copy_stream after model forward completes.

Call this after the model forward pass but before CUDA graph capture ends. This ensures copy_stream is rejoined for any prefetches started during the forward pass.

We join ALL layers that have _prefetch_in_capture=True, meaning their prefetch was started during capture but not yet waited on (joined). This handles both full and piecewise cudagraph modes correctly: - Full mode: joins layers 0..prefetch_step-1 (prefetched by last layers) - Piecewise mode: joins only layers prefetched by THIS subgraph's layers

Source code in vllm/model_executor/offloader/prefetch.py
def join_after_forward(self):
    """Rejoin copy_stream into the compute stream after model forward.

    Invoke after the forward pass but before CUDA graph capture ends, so
    every prefetch launched during the forward pass is joined back into
    the graph.

    Exactly the layers still flagged ``_prefetch_in_capture`` are joined:
    their prefetch started during capture but was never waited on. This
    covers both cudagraph modes:
    - Full: layers 0..prefetch_step-1 (prefetched by the last layers)
    - Piecewise: only layers prefetched by THIS subgraph's layers
    """
    if not self.module_offloaders:
        return
    current = torch.cuda.current_stream()
    for module_offloader in self.module_offloaders:
        if not module_offloader._prefetch_in_capture:
            continue
        current.wait_event(module_offloader._copy_done_event)
        module_offloader._prefetch_in_capture = False

post_init

post_init()

Allocate static buffer pool and start initial prefetches.

Note: Parameters have already been offloaded to CPU during wrap_modules() (in `_CpuParamOffloader.__init__`), so GPU memory is available for the static buffer pool.

Source code in vllm/model_executor/offloader/prefetch.py
def post_init(self) -> None:
    """Allocate static buffer pool and start initial prefetches.

    Note: Parameters have already been offloaded to CPU during wrap_modules()
    (in _CpuParamOffloader.__init__), so GPU memory is available for the
    static buffer pool.
    """
    # Sync CPU storage with current param.data BEFORE collecting param info.
    # This is needed because process_weights_after_loading may have:
    # 1. Transformed weights (quantization, transpose, etc.)
    # 2. Created new CPU tensors via device_loading_context
    # Our _cpu_storage would be stale otherwise.
    for offloader in self.module_offloaders:
        offloader.sync_cpu_storage()

    # Collect parameter info (now using synced CPU storage)
    param_infos: list[ParamInfo] = []
    device: torch.device | None = None

    for offloader in self.module_offloaders:
        param_infos.extend(offloader.get_param_infos())
        if device is None:
            device = offloader.device

    if device is None:
        # No modules to offload
        return

    # Allocate static buffer pool
    self.buffer_pool = StaticBufferPool(
        param_infos=param_infos,
        slot_capacity=self.prefetch_step,
        device=device,
    )

    # Assign buffer slots and point parameters to GPU buffers
    # (layer N uses slot N % prefetch_step, matching the prefetch cadence)
    for idx, offloader in enumerate(self.module_offloaders):
        slot_idx = idx % self.prefetch_step
        offloader.assign_buffer_slot(self.buffer_pool, slot_idx)

    # Collect offloaded bytes
    for offloader in self.module_offloaders:
        offloader.post_init()
        self.total_offloaded_bytes += offloader.offloaded_bytes

    logger.info_once(
        f"[PrefetchOffloader] Initialized {len(self.module_offloaders)} modules. "
        f"Total GPU memory saved: {self.total_offloaded_bytes / 1e9:.4f} GB, "
        f"Static buffer pool: {self.buffer_pool.total_bytes / 1e9:.4f} GB "
        f"(group_size={self.group_size}, num_in_group={self.num_in_group}, "
        f"prefetch_step={self.prefetch_step}, mode={self.mode})"
    )

    # Start initial prefetches
    for i in range(min(self.prefetch_step, len(self.module_offloaders))):
        self.module_offloaders[i].start_onload_to_static()

sync_prev_onload

sync_prev_onload()

Sync previous onload operations.

Ensures any H2D copies in flight on copy_stream complete before the compute stream continues. Call this before CUDA graph capture/replay or when synchronization is needed.

Source code in vllm/model_executor/offloader/prefetch.py
def sync_prev_onload(self):
    """Drain pending onload work from the copy stream.

    Makes every in-flight H2D copy on ``copy_stream`` complete before the
    current (compute) stream proceeds. Call before CUDA graph
    capture/replay, or whenever a full synchronization point is needed.
    """
    compute_stream = torch.cuda.current_stream()
    compute_stream.wait_stream(self.copy_stream)

wrap_modules

wrap_modules(
    modules_generator: Generator[Module, None, None],
) -> list[Module]

Wrap modules with prefetch offloading logic.

Source code in vllm/model_executor/offloader/prefetch.py
def wrap_modules(
    self,
    modules_generator: Generator[nn.Module, None, None],
) -> list[nn.Module]:
    """Wrap modules with prefetch offloading logic.

    Consumes the generator, selects the last ``num_in_group`` layers of
    every ``group_size`` block for offloading, and hooks their forwards.
    Returns all modules (offloaded or not) in original order.
    """
    assert len(self.module_offloaders) == 0, (
        "wrap_modules should only be called once"
    )

    all_modules: list[nn.Module] = []
    offload_modules: list[nn.Module] = []
    # Positions within a group at or beyond this threshold are offloaded.
    threshold = self.group_size - self.num_in_group

    for module_index, module in enumerate(modules_generator):
        all_modules.append(module)

        # Keep only the last num_in_group layers of each group.
        if module_index % self.group_size < threshold:
            continue

        param_names = [name for name, _ in module.named_parameters()]
        if self.offload_params:
            # Dot-delimited containment so fragments match whole path
            # components of the parameter name.
            param_names = [
                name
                for name in param_names
                if any(f".{p}." in f".{name}." for p in self.offload_params)
            ]

        if not param_names:
            continue  # nothing in this layer matches the whitelist

        offload_modules.append(module)
        self.module_offloaders.append(
            _ModuleOffloader(
                mode=self.mode,
                module=module,
                copy_stream=self.copy_stream,
                whitelist_param_names=param_names,
                layer_idx=len(self.module_offloaders),
            )
        )

    for index, module in enumerate(offload_modules):
        self._hook_module_forward(index, module)

    return all_modules

StaticBufferPool

Pre-allocated GPU buffer pool for offloaded parameters.

Allocates slot_capacity copies of each unique parameter (name, shape, stride, dtype), allowing for double/triple buffering during prefetch.

Buffer slots are reused circularly: layer N uses slot (N % slot_capacity).

The key includes parameter name to prevent different parameters within the same layer from sharing buffers. Parameters with the same name across different layers share buffers via the slot mechanism.

Source code in vllm/model_executor/offloader/prefetch.py
class StaticBufferPool:
    """Pre-allocated GPU buffer pool for offloaded parameters.

    Allocates slot_capacity copies of each unique parameter
    (name, shape, stride, dtype), allowing for double/triple buffering
    during prefetch.

    Buffer slots are reused circularly: layer N uses slot (N % slot_capacity).

    The key includes parameter name to prevent different parameters within
    the same layer from sharing buffers. Parameters with the same name
    across different layers share buffers via the slot mechanism.
    """

    def __init__(
        self,
        param_infos: list[ParamInfo],
        slot_capacity: int,
        device: torch.device,
    ):
        self.slot_capacity = slot_capacity
        self.total_bytes = 0
        self._device = device

        # Group by (name, shape, stride, dtype) - only allocate unique
        # combinations (first occurrence wins; duplicates share the buffers)
        unique_params: dict[tuple, ParamInfo] = {}
        for info in param_infos:
            if info.key not in unique_params:
                unique_params[info.key] = info

        # Allocate buffers: key -> list of tensors (one per slot)
        self._buffers: dict[tuple, list[torch.Tensor]] = {}
        for key, info in unique_params.items():
            slot_tensors = []
            for _ in range(slot_capacity):
                # Use empty_strided to preserve parameter's memory layout
                buf = torch.empty_strided(
                    size=info.shape,
                    stride=info.stride,
                    dtype=info.dtype,
                    device=device,
                )
                slot_tensors.append(buf)
                self.total_bytes += info.num_bytes
            self._buffers[key] = slot_tensors

        logger.debug(
            "[StaticBufferPool] Allocated %d unique (name, shape, stride, dtype), "
            "%d slots each, total %.4f GB",
            len(unique_params),
            slot_capacity,
            self.total_bytes / 1e9,
        )

    def get_buffer(
        self,
        name: str,
        shape: tuple[int, ...],
        stride: tuple[int, ...],
        dtype: torch.dtype,
        slot_idx: int,
    ) -> torch.Tensor:
        """Get a static buffer for the given name/shape/stride/dtype/slot."""
        key = (name, shape, stride, dtype)
        return self._buffers[key][slot_idx % self.slot_capacity]

get_buffer

get_buffer(
    name: str,
    shape: tuple[int, ...],
    stride: tuple[int, ...],
    dtype: dtype,
    slot_idx: int,
) -> Tensor

Get a static buffer for the given name/shape/stride/dtype/slot.

Source code in vllm/model_executor/offloader/prefetch.py
def get_buffer(
    self,
    name: str,
    shape: tuple[int, ...],
    stride: tuple[int, ...],
    dtype: torch.dtype,
    slot_idx: int,
) -> torch.Tensor:
    """Look up the pre-allocated static buffer for one parameter slot.

    ``slot_idx`` wraps modulo ``slot_capacity``, so callers may pass a raw
    layer index directly.
    """
    slot_tensors = self._buffers[(name, shape, stride, dtype)]
    return slot_tensors[slot_idx % self.slot_capacity]

_BaseParamOffloader

Bases: ABC

Base class for parameter offloading strategies.

Source code in vllm/model_executor/offloader/prefetch.py
class _BaseParamOffloader(ABC):
    """Base class for parameter offloading strategies."""

    # CPU storage for offloaded parameters (set by subclasses)
    _cpu_storage: torch.Tensor | None
    # GPU buffer reference (set by subclasses when using static buffers)
    _gpu_buffer: torch.Tensor | None

    @staticmethod
    def create(mode: str, **kwargs) -> "_BaseParamOffloader":
        """Factory method to create appropriate offloader for mode."""
        if mode == "cpu":
            return _CpuParamOffloader(**kwargs)
        else:
            raise ValueError(f"Unknown offload mode: {mode}")

    def __init__(self, module: nn.Module, param_name: str):
        self._module = module
        self._param_name = param_name
        self.offloaded_bytes = 0
        self._cpu_storage = None
        self._gpu_buffer = None

    @property
    def _param(self) -> nn.Parameter:
        """Get the parameter being offloaded.

        Supports dotted names (e.g. 'self_attn.qkv_proj.weight') by
        traversing the module hierarchy.
        """
        obj: Any = self._module
        for attr in self._param_name.split("."):
            obj = getattr(obj, attr)
        return obj

    def post_init(self):
        """Initialize offloading (move parameter to storage)."""
        return

    @abstractmethod
    def sync_cpu_storage(self) -> None:
        """Sync CPU storage with current param.data.

        Called after process_weights_after_loading to update _cpu_storage
        with the final processed weights.
        """
        pass

    @abstractmethod
    def assign_static_buffer(self, gpu_buffer: torch.Tensor) -> None:
        """Point parameter data to GPU static buffer."""
        pass

_param property

_param: Parameter

Get the parameter being offloaded.

Supports dotted names (e.g. 'self_attn.qkv_proj.weight') by traversing the module hierarchy.

assign_static_buffer abstractmethod

assign_static_buffer(gpu_buffer: Tensor) -> None

Point parameter data to GPU static buffer.

Source code in vllm/model_executor/offloader/prefetch.py
@abstractmethod
def assign_static_buffer(self, gpu_buffer: torch.Tensor) -> None:
    """Point parameter data to GPU static buffer.

    After this call the parameter aliases ``gpu_buffer``; prefetch then
    fills that buffer from CPU storage.
    """
    pass

create staticmethod

create(mode: str, **kwargs) -> _BaseParamOffloader

Factory method to create appropriate offloader for mode.

Source code in vllm/model_executor/offloader/prefetch.py
@staticmethod
def create(mode: str, **kwargs) -> "_BaseParamOffloader":
    """Factory method to create appropriate offloader for mode."""
    if mode == "cpu":
        return _CpuParamOffloader(**kwargs)
    else:
        raise ValueError(f"Unknown offload mode: {mode}")

post_init

post_init()

Initialize offloading (move parameter to storage).

Source code in vllm/model_executor/offloader/prefetch.py
def post_init(self):
    """Initialize offloading (move parameter to storage).

    Base implementation is a deliberate no-op; subclasses override.
    """

sync_cpu_storage abstractmethod

sync_cpu_storage() -> None

Sync CPU storage with current param.data.

Called after process_weights_after_loading to update _cpu_storage with the final processed weights.

Source code in vllm/model_executor/offloader/prefetch.py
@abstractmethod
def sync_cpu_storage(self) -> None:
    """Sync CPU storage with current param.data.

    Called after process_weights_after_loading to update _cpu_storage
    with the final processed weights (which may have been transformed
    or reallocated during loading).
    """
    pass

_CpuParamOffloader

Bases: _BaseParamOffloader

Offload parameter to pinned CPU memory.

Uses GPU static buffers as the actual parameter, with CPU storage kept separately. This ensures torch.compile sees GPU tensors at trace time.

The offloading happens in two phases: 1. `__init__()` - copies GPU data to CPU, frees GPU memory immediately 2. `assign_static_buffer()` - points param.data to GPU static buffer

Source code in vllm/model_executor/offloader/prefetch.py
class _CpuParamOffloader(_BaseParamOffloader):
    """Offload parameter to pinned CPU memory.

    Uses GPU static buffers as the actual parameter, with CPU storage
    kept separately. This ensures torch.compile sees GPU tensors at trace time.

    The offloading happens in two phases:
    1. __init__() - copies GPU data to CPU, frees GPU memory immediately
    2. assign_static_buffer() - points param.data to GPU static buffer
    """

    def __init__(self, module: nn.Module, param_name: str):
        super().__init__(module, param_name)
        # Host-side copy of the weights; the source of truth for prefetch.
        self._cpu_storage: torch.Tensor | None = None
        self._gpu_buffer: torch.Tensor | None = None  # Store reference to GPU buffer

        # Offload to CPU immediately to free GPU memory during model loading
        # (this also sets self.offloaded_bytes).
        self._offload_to_cpu_internal()

    def _offload_to_cpu_internal(self):
        """Copy parameter data to pinned CPU storage and free GPU memory.

        This replaces param.data with CPU storage, allowing weight loading
        to continue writing to CPU memory. GPU memory is freed when the
        original GPU tensor is garbage collected.
        """
        param = self._param
        pin_memory = is_pin_memory_available()

        # Create pinned CPU storage and copy current GPU data.
        # empty_strided preserves the original stride/layout exactly.
        self._cpu_storage = torch.empty_strided(
            size=param.data.size(),
            stride=param.data.stride(),
            dtype=param.data.dtype,
            layout=param.data.layout,
            device="cpu",
            pin_memory=pin_memory,
        )
        self._cpu_storage.copy_(param.data)

        self.offloaded_bytes = (
            self._cpu_storage.numel() * self._cpu_storage.element_size()
        )

        # Point param.data to CPU storage - this allows weight loading to work
        # and frees GPU memory when the original GPU tensor is garbage collected
        param.data = self._cpu_storage

    def _update_cpu_storage_from_param(self) -> None:
        """Update _cpu_storage from current param.data, ensuring pinned memory.

        After process_weights_after_loading, device_loading_context creates
        non-pinned CPU tensors via `p.data = p.data.to("cpu")`. Using
        non-pinned memory with `copy_(src, non_blocking=True)` causes CUDA to
        perform a stream synchronization before the copy, breaking the
        event-based fork synchronization and potentially allowing the copy
        to overwrite the GPU buffer while the compute stream still reads it.

        This method ensures _cpu_storage always uses pinned memory when
        available, re-pinning if necessary.
        """
        param = self._param

        if param.data.device.type == "cpu":
            if is_pin_memory_available() and not param.data.is_pinned():
                # Re-pin: clone the layout into a fresh pinned host tensor.
                pinned = torch.empty_strided(
                    size=param.data.size(),
                    stride=param.data.stride(),
                    dtype=param.data.dtype,
                    layout=param.data.layout,
                    device="cpu",
                    pin_memory=True,
                )
                pinned.copy_(param.data)
                self._cpu_storage = pinned
            else:
                # Already pinned (or pinning unavailable): adopt it directly.
                self._cpu_storage = param.data
        else:
            # param.data is on GPU - copy to existing CPU storage
            assert self._cpu_storage is not None
            self._cpu_storage.copy_(param.data)

    def assign_static_buffer(self, gpu_buffer: torch.Tensor) -> None:
        """Point parameter data to GPU static buffer.

        This is called after weight loading AND process_weights_after_loading
        complete. At this point:
        - param.data may have been replaced by device_loading_context
          (which creates new CPU tensors after quantization processing)
        - We need to update _cpu_storage to point to current param.data
          so that prefetch copies the processed weights, not stale data
        - Then point param.data to the GPU buffer for torch.compile
        """
        assert self._cpu_storage is not None, (
            "_offload_to_cpu_internal() must be called before assign_static_buffer()"
        )

        # Get current parameter (may have been replaced by
        # process_weights_after_loading)
        param = self._param

        # Update _cpu_storage to current param.data. This is critical because:
        # 1. process_weights_after_loading may transform weights (quantization)
        # 2. device_loading_context creates NEW CPU tensors when moving back
        # 3. Our old _cpu_storage would have pre-processed or stale data
        self._update_cpu_storage_from_param()

        # Store reference to GPU buffer for use in start_onload
        self._gpu_buffer = gpu_buffer

        # Point parameter to static GPU buffer - this is what torch.compile sees
        param.data = gpu_buffer

    def sync_cpu_storage(self) -> None:
        """Sync CPU storage with current param.data.

        Called after process_weights_after_loading to update _cpu_storage
        with the final processed weights. This is critical because:
        1. process_weights_after_loading may transform weights (quantization)
        2. device_loading_context creates NEW CPU tensors when moving back
        3. Our old _cpu_storage would have pre-processed or stale data
        """
        self._update_cpu_storage_from_param()

    def post_init(self):
        """No-op: offloading was already done in _offload_to_cpu_internal()
        and assign_static_buffer()."""
        pass

_offload_to_cpu_internal

_offload_to_cpu_internal()

Copy parameter data to pinned CPU storage and free GPU memory.

This replaces param.data with CPU storage, allowing weight loading to continue writing to CPU memory. GPU memory is freed when the original GPU tensor is garbage collected.

Source code in vllm/model_executor/offloader/prefetch.py
def _offload_to_cpu_internal(self):
    """Move this parameter's data into (pinned, if available) CPU memory.

    After this call param.data aliases the new CPU tensor, so subsequent
    weight loading writes straight into host memory; the original GPU
    tensor is freed once garbage collected. Records the number of bytes
    moved in ``self.offloaded_bytes``.
    """
    src = self._param.data

    # Mirror the GPU tensor's exact size/stride/layout on the host so the
    # memory layout is preserved; pin when the platform supports it.
    host = torch.empty_strided(
        size=src.size(),
        stride=src.stride(),
        dtype=src.dtype,
        layout=src.layout,
        device="cpu",
        pin_memory=is_pin_memory_available(),
    )
    host.copy_(src)
    self._cpu_storage = host

    self.offloaded_bytes = host.numel() * host.element_size()

    # Redirect the parameter to host memory; the old GPU allocation is
    # released when its last reference drops.
    self._param.data = host

_update_cpu_storage_from_param

_update_cpu_storage_from_param() -> None

Update _cpu_storage from current param.data, ensuring pinned memory.

After process_weights_after_loading, device_loading_context creates non-pinned CPU tensors via p.data = p.data.to("cpu"). Using non-pinned memory with copy_(src, non_blocking=True) causes CUDA to perform a stream synchronization before the copy, breaking the event-based fork synchronization and potentially allowing the copy to overwrite the GPU buffer while the compute stream still reads it.

This method ensures _cpu_storage always uses pinned memory when available, re-pinning if necessary.

Source code in vllm/model_executor/offloader/prefetch.py
def _update_cpu_storage_from_param(self) -> None:
    """Refresh _cpu_storage from the current param.data, keeping it pinned.

    device_loading_context may have replaced param.data with a *non-pinned*
    CPU tensor (via ``p.data = p.data.to("cpu")``). A non_blocking H2D copy
    from non-pinned memory forces a hidden stream synchronization that
    defeats the event-based fork and can let the copy race the compute
    stream's reads of the GPU buffer, so we re-pin here whenever pinned
    memory is available.
    """
    data = self._param.data

    if data.device.type != "cpu":
        # Parameter currently lives on GPU; refresh the existing host copy.
        assert self._cpu_storage is not None
        self._cpu_storage.copy_(data)
        return

    if is_pin_memory_available() and not data.is_pinned():
        # Re-pin: reproduce the layout in a fresh pinned host tensor.
        repinned = torch.empty_strided(
            size=data.size(),
            stride=data.stride(),
            dtype=data.dtype,
            layout=data.layout,
            device="cpu",
            pin_memory=True,
        )
        repinned.copy_(data)
        self._cpu_storage = repinned
    else:
        # Already pinned (or pinning unavailable): adopt it directly.
        self._cpu_storage = data

assign_static_buffer

assign_static_buffer(gpu_buffer: Tensor) -> None

Point parameter data to GPU static buffer.

This is called after weight loading AND process_weights_after_loading complete. At this point: - param.data may have been replaced by device_loading_context (which creates new CPU tensors after quantization processing) - We need to update _cpu_storage to point to current param.data so that prefetch copies the processed weights, not stale data - Then point param.data to the GPU buffer for torch.compile

Source code in vllm/model_executor/offloader/prefetch.py
def assign_static_buffer(self, gpu_buffer: torch.Tensor) -> None:
    """Redirect the parameter to its static GPU buffer.

    Invoked once weight loading and process_weights_after_loading are
    both finished. By then param.data may be a brand-new CPU tensor
    (device_loading_context recreates it after quantization processing),
    so we first re-sync _cpu_storage from param.data -- otherwise
    prefetch would upload stale pre-processing weights -- and only then
    point param.data at the shared GPU buffer that torch.compile traces.
    """
    assert self._cpu_storage is not None, (
        "_offload_to_cpu_internal() must be called before assign_static_buffer()"
    )

    # Pull the (possibly transformed/replaced) weights into _cpu_storage
    # before param.data gets overwritten with the GPU buffer below.
    self._update_cpu_storage_from_param()

    # Keep the buffer around so start_onload can target it later.
    self._gpu_buffer = gpu_buffer

    # From torch.compile's perspective, the parameter *is* this buffer.
    self._param.data = gpu_buffer

post_init

post_init()

No-op: offloading was already done in `_offload_to_cpu_internal()` / `assign_static_buffer()`.

Source code in vllm/model_executor/offloader/prefetch.py
def post_init(self):
    """Nothing to do here: the work already happened in
    _offload_to_cpu_internal() and assign_static_buffer()."""

sync_cpu_storage

sync_cpu_storage() -> None

Sync CPU storage with current param.data.

Called after process_weights_after_loading to update _cpu_storage with the final processed weights. This is critical because: 1. process_weights_after_loading may transform weights (quantization) 2. device_loading_context creates NEW CPU tensors when moving back 3. Our old _cpu_storage would have pre-processed or stale data

Source code in vllm/model_executor/offloader/prefetch.py
def sync_cpu_storage(self) -> None:
    """Re-sync CPU storage from param.data after weight post-processing.

    process_weights_after_loading can quantize/transform the weights, and
    device_loading_context then rebuilds param.data as a fresh CPU tensor.
    Without this re-sync, prefetch would upload stale pre-processing data.
    """
    self._update_cpu_storage_from_param()

_ModuleOffloader

Manages offloading for a single module.

Uses static buffers from a shared pool instead of dynamic allocation.

Source code in vllm/model_executor/offloader/prefetch.py
class _ModuleOffloader:
    """Manages offloading for a single module.

    Uses static buffers from a shared pool instead of dynamic allocation.
    """

    def __init__(
        self,
        mode: str,
        module: nn.Module,
        copy_stream: torch.cuda.Stream,
        whitelist_param_names: list[str],
        layer_idx: int,
    ):
        self.mode = mode
        self.module = module
        # Device of the module's first parameter (assumes all offloaded
        # params share one device -- TODO confirm for sharded modules).
        self.device = next(module.parameters()).device
        self.copy_stream = copy_stream
        self.layer_idx = layer_idx
        # Total bytes moved to CPU; accumulated in post_init().
        self.offloaded_bytes = 0

        # Event to signal when H2D copy to static buffer is complete.
        # Used for per-layer synchronization (both eager and capture modes).
        self._copy_done_event = torch.cuda.Event()

        # Track whether _copy_done_event is valid for eager-mode wait_event.
        # False when: (1) never recorded, or (2) last recorded during a
        # cudagraph capture (events become invalid after capture ends).
        # In these cases we fall back to wait_stream.
        self._event_valid_for_eager = False

        # Track if last prefetch was started during CUDA graph capture.
        # Used to skip wait_event during capture for pre-capture prefetches.
        self._prefetch_in_capture = False

        assert self.device != torch.device("cpu"), (
            "Module parameters should not already be on CPU "
            "(offloader handles CPU placement)"
        )

        # Buffer pool and slot (assigned in assign_buffer_slot)
        self._buffer_pool: StaticBufferPool | None = None
        self._buffer_slot_idx: int = 0

        # Every whitelisted name must resolve to an actual parameter.
        param_dict = dict(self.module.named_parameters())
        assert all(name in param_dict for name in whitelist_param_names), (
            f"Whitelist params {whitelist_param_names} not found in module params "
            f"{list(param_dict.keys())}"
        )

        # Constructing the per-param offloaders copies weights to CPU
        # immediately (see _CpuParamOffloader.__init__).
        self._param_offloaders = {
            name: _BaseParamOffloader.create(mode, module=module, param_name=name)
            for name in whitelist_param_names
        }

    def post_init(self):
        """Collect total offloaded bytes (offloading already done in __init__)."""
        for param_offloader in self._param_offloaders.values():
            param_offloader.post_init()
            self.offloaded_bytes += param_offloader.offloaded_bytes

    def sync_cpu_storage(self):
        """Sync CPU storage with current param.data.

        Called after process_weights_after_loading to ensure _cpu_storage
        contains the final processed weights, not stale pre-loading data.
        """
        for param_offloader in self._param_offloaders.values():
            param_offloader.sync_cpu_storage()

    def get_param_infos(self) -> list[ParamInfo]:
        """Get parameter metadata for buffer pool allocation.

        Note: sync_cpu_storage() must be called before this method to ensure
        _cpu_storage reflects the final processed weights (after quantization).
        """
        infos = []
        for name, offloader in self._param_offloaders.items():
            cpu_storage = offloader._cpu_storage
            assert cpu_storage is not None, "CPU storage not initialized"
            infos.append(
                ParamInfo(
                    name=name,
                    shape=tuple(cpu_storage.shape),
                    stride=tuple(cpu_storage.stride()),
                    dtype=cpu_storage.dtype,
                )
            )
        return infos

    def assign_buffer_slot(self, pool: StaticBufferPool, slot_idx: int):
        """Assign this module to a buffer slot in the pool.

        Also assigns static GPU buffers to each parameter offloader,
        which moves the parameter data to point to the GPU buffer.
        """
        self._buffer_pool = pool
        self._buffer_slot_idx = slot_idx

        # Assign static buffers to parameters
        # Use CPU storage shape/stride/dtype since param.data is now empty
        for name, offloader in self._param_offloaders.items():
            cpu_storage = offloader._cpu_storage
            assert cpu_storage is not None, "CPU storage not initialized"
            buffer = pool.get_buffer(
                name=name,
                shape=tuple(cpu_storage.shape),
                stride=tuple(cpu_storage.stride()),
                dtype=cpu_storage.dtype,
                slot_idx=slot_idx,
            )
            offloader.assign_static_buffer(buffer)

    def start_onload_to_static(self):
        """Start async copy from CPU storage to GPU buffer.

        Uses event-based forking to join copy_stream to CUDA graph capture.
        This ensures H2D copies are properly captured when recording a graph.

        IMPORTANT: We must wait for the compute stream before copying, because
        the previous layer's forward may still be using the buffer (GPU ops are
        async). Without this sync, we could overwrite the buffer while it's
        being read.
        """
        assert self._buffer_pool is not None, "Buffer pool not assigned"

        # Track if this prefetch is being captured (for _wait_for_layer logic)
        self._prefetch_in_capture = torch.cuda.is_current_stream_capturing()

        # Fork: record event on compute stream, copy_stream waits on it
        # This joins copy_stream to any active CUDA graph capture
        fork_event = torch.cuda.Event()
        torch.cuda.current_stream().record_event(fork_event)
        self.copy_stream.wait_event(fork_event)

        with torch.cuda.stream(self.copy_stream):
            for name, offloader in self._param_offloaders.items():
                cpu_storage = offloader._cpu_storage
                gpu_buffer = offloader._gpu_buffer
                assert cpu_storage is not None, "CPU storage not initialized"
                assert gpu_buffer is not None, "GPU buffer not assigned"
                assert not is_pin_memory_available() or cpu_storage.is_pinned(), (
                    f"CPU storage for {name} is not pinned! "
                    "non_blocking=True H2D copy from non-pinned memory "
                    "causes stream synchronization that breaks "
                    "event-based fork synchronization."
                )
                gpu_buffer.copy_(cpu_storage, non_blocking=True)

        # Record completion event for _wait_for_layer to use.
        # record(self.copy_stream) targets the copy stream explicitly, so it
        # does not need to sit inside the `with` block above.
        self._copy_done_event.record(self.copy_stream)
        # Event is only valid for eager wait_event if recorded outside capture.
        # Events recorded during capture become invalid after capture ends.
        self._event_valid_for_eager = not torch.cuda.is_current_stream_capturing()

assign_buffer_slot

assign_buffer_slot(pool: StaticBufferPool, slot_idx: int)

Assign this module to a buffer slot in the pool.

Also assigns static GPU buffers to each parameter offloader, which moves the parameter data to point to the GPU buffer.

Source code in vllm/model_executor/offloader/prefetch.py
def assign_buffer_slot(self, pool: StaticBufferPool, slot_idx: int):
    """Bind this module to slot `slot_idx` of the shared buffer pool.

    Fetches the matching static GPU buffer for every offloaded parameter
    -- keyed by the CPU storage's shape/stride/dtype, since param.data no
    longer describes the real layout -- and repoints each parameter at it.
    """
    self._buffer_pool = pool
    self._buffer_slot_idx = slot_idx

    for param_name, param_offloader in self._param_offloaders.items():
        storage = param_offloader._cpu_storage
        assert storage is not None, "CPU storage not initialized"
        static_buf = pool.get_buffer(
            name=param_name,
            shape=tuple(storage.shape),
            stride=tuple(storage.stride()),
            dtype=storage.dtype,
            slot_idx=slot_idx,
        )
        param_offloader.assign_static_buffer(static_buf)

get_param_infos

get_param_infos() -> list[ParamInfo]

Get parameter metadata for buffer pool allocation.

Note: sync_cpu_storage() must be called before this method to ensure _cpu_storage reflects the final processed weights (after quantization).

Source code in vllm/model_executor/offloader/prefetch.py
def get_param_infos(self) -> list[ParamInfo]:
    """Describe each offloaded parameter for buffer-pool allocation.

    Call sync_cpu_storage() before this so the reported shapes, strides,
    and dtypes match the final processed weights, not pre-loading data.
    """
    param_infos = []
    for param_name, param_offloader in self._param_offloaders.items():
        storage = param_offloader._cpu_storage
        assert storage is not None, "CPU storage not initialized"
        param_infos.append(
            ParamInfo(
                name=param_name,
                shape=tuple(storage.shape),
                stride=tuple(storage.stride()),
                dtype=storage.dtype,
            )
        )
    return param_infos

post_init

post_init()

Collect total offloaded bytes (offloading already done in `__init__`).

Source code in vllm/model_executor/offloader/prefetch.py
def post_init(self):
    """Run each parameter offloader's post_init hook and tally the
    total number of bytes this module has offloaded to CPU."""
    for offloader in self._param_offloaders.values():
        offloader.post_init()
        self.offloaded_bytes += offloader.offloaded_bytes

start_onload_to_static

start_onload_to_static()

Start async copy from CPU storage to GPU buffer.

Uses event-based forking to join copy_stream to CUDA graph capture. This ensures H2D copies are properly captured when recording a graph.

IMPORTANT: We must wait for the compute stream before copying, because the previous layer's forward may still be using the buffer (GPU ops are async). Without this sync, we could overwrite the buffer while it's being read.

Source code in vllm/model_executor/offloader/prefetch.py
def start_onload_to_static(self):
    """Start async copy from CPU storage to GPU buffer.

    Uses event-based forking to join copy_stream to CUDA graph capture.
    This ensures H2D copies are properly captured when recording a graph.

    IMPORTANT: We must wait for the compute stream before copying, because
    the previous layer's forward may still be using the buffer (GPU ops are
    async). Without this sync, we could overwrite the buffer while it's
    being read.
    """
    assert self._buffer_pool is not None, "Buffer pool not assigned"

    # Track if this prefetch is being captured (for _wait_for_layer logic)
    self._prefetch_in_capture = torch.cuda.is_current_stream_capturing()

    # Fork: record event on compute stream, copy_stream waits on it
    # This joins copy_stream to any active CUDA graph capture
    fork_event = torch.cuda.Event()
    torch.cuda.current_stream().record_event(fork_event)
    self.copy_stream.wait_event(fork_event)

    with torch.cuda.stream(self.copy_stream):
        for name, offloader in self._param_offloaders.items():
            cpu_storage = offloader._cpu_storage
            gpu_buffer = offloader._gpu_buffer
            assert cpu_storage is not None, "CPU storage not initialized"
            assert gpu_buffer is not None, "GPU buffer not assigned"
            # Pinned source is required: see _update_cpu_storage_from_param.
            assert not is_pin_memory_available() or cpu_storage.is_pinned(), (
                f"CPU storage for {name} is not pinned! "
                "non_blocking=True H2D copy from non-pinned memory "
                "causes stream synchronization that breaks "
                "event-based fork synchronization."
            )
            gpu_buffer.copy_(cpu_storage, non_blocking=True)

    # Record completion event for _wait_for_layer to use.
    # record(self.copy_stream) targets the copy stream explicitly, so this
    # does not need to sit inside the `with` block above.
    self._copy_done_event.record(self.copy_stream)
    # Event is only valid for eager wait_event if recorded outside capture.
    # Events recorded during capture become invalid after capture ends.
    self._event_valid_for_eager = not torch.cuda.is_current_stream_capturing()

sync_cpu_storage

sync_cpu_storage()

Sync CPU storage with current param.data.

Called after process_weights_after_loading to ensure _cpu_storage contains the final processed weights, not stale pre-loading data.

Source code in vllm/model_executor/offloader/prefetch.py
def sync_cpu_storage(self):
    """Propagate sync_cpu_storage() to every parameter offloader.

    Run after process_weights_after_loading so each _cpu_storage holds
    the final processed weights instead of stale pre-loading data.
    """
    for offloader in self._param_offloaders.values():
        offloader.sync_cpu_storage()