Skip to content

vllm.model_executor.offloader

Model parameter offloading infrastructure.

Modules:

Name Description
base

Base classes for model parameter offloading.

prefetch

Prefetch-based CPU offloading with async prefetching.

prefetch_ops

Custom ops for prefetch offloader torch.compile + CUDA graph compatibility.

uva

UVA-based CPU offloading using Unified Virtual Addressing.

BaseOffloader

Bases: ABC

Base class for model parameter offloading strategies.

Offloaders control how model parameters are stored and loaded during inference. Different strategies trade memory for compute/transfer time.

Source code in vllm/model_executor/offloader/base.py
class BaseOffloader(ABC):
    """Abstract interface shared by all parameter-offloading strategies.

    An offloader decides how model parameters are stored and loaded during
    inference; each strategy trades memory for compute/transfer time.
    Only ``wrap_modules`` is abstract. Every other hook defaults to a no-op
    so that subclasses override just the ones they need.
    """

    @abstractmethod
    def wrap_modules(
        self,
        modules_generator: Generator[nn.Module, None, None],
    ) -> list[nn.Module]:
        """Install offloading logic on the modules a generator produces.

        Args:
            modules_generator: Generator yielding modules to potentially offload.

        Returns:
            List of modules, potentially with offloading hooks installed.
        """

    def post_init(self):
        """Hook invoked once model construction has completed.

        Subclasses may use it to:
        - Finalize parameter storage
        - Start initial prefetching
        - Allocate shared resources

        The base implementation does nothing.
        """
        return None

    def sync_prev_onload(self) -> None:  # noqa: B027
        """No-op hook for syncing earlier onload operations (subclass override)."""

    def join_after_forward(self) -> None:  # noqa: B027
        """No-op hook for joining streams after a forward pass (subclass override)."""

    def _wait_for_layer(self, layer_idx: int) -> None:  # noqa: B027
        """No-op hook for awaiting the prefetch of layer ``layer_idx``."""

    def _start_prefetch(self, layer_idx: int) -> None:  # noqa: B027
        """No-op hook for starting the prefetch of layer ``layer_idx``."""

_start_prefetch

_start_prefetch(layer_idx: int) -> None

Start layer prefetch. Override in subclasses.

Source code in vllm/model_executor/offloader/base.py
def _start_prefetch(self, layer_idx: int) -> None:  # noqa: B027
    """Begin prefetching layer ``layer_idx``.

    The base implementation is a no-op; subclasses override it.
    """

_wait_for_layer

_wait_for_layer(layer_idx: int) -> None

Wait for layer prefetch. Override in subclasses.

Source code in vllm/model_executor/offloader/base.py
def _wait_for_layer(self, layer_idx: int) -> None:  # noqa: B027
    """Wait for the prefetch of layer ``layer_idx``.

    The base implementation is a no-op; subclasses override it.
    """

join_after_forward

join_after_forward() -> None

Join streams after forward. Override in subclasses.

Source code in vllm/model_executor/offloader/base.py
def join_after_forward(self) -> None:  # noqa: B027
    """Join streams once the forward pass is done.

    The base implementation is a no-op; subclasses override it.
    """

post_init

post_init()

Called after model construction completes.

Offloaders can use this to:

- Finalize parameter storage
- Start initial prefetching
- Allocate shared resources

Source code in vllm/model_executor/offloader/base.py
def post_init(self):
    """Hook invoked once model construction has completed.

    Subclasses may use it to:
    - Finalize parameter storage
    - Start initial prefetching
    - Allocate shared resources

    The base implementation does nothing.
    """
    return None

sync_prev_onload

sync_prev_onload() -> None

Sync previous onload operations. Override in subclasses.

Source code in vllm/model_executor/offloader/base.py
def sync_prev_onload(self) -> None:  # noqa: B027
    """Synchronize with previously issued onload operations.

    The base implementation is a no-op; subclasses override it.
    """

wrap_modules abstractmethod

wrap_modules(
    modules_generator: Generator[Module, None, None],
) -> list[Module]

Wrap modules with offloading logic.

Parameters:

Name Type Description Default
modules_generator Generator[Module, None, None]

Generator yielding modules to potentially offload.

required

Returns:

Type Description
list[Module]

List of modules, potentially with offloading hooks installed.

Source code in vllm/model_executor/offloader/base.py
@abstractmethod
def wrap_modules(
    self,
    modules_generator: Generator[nn.Module, None, None],
) -> list[nn.Module]:
    """Install offloading logic on the modules a generator produces.

    Abstract: concrete offloaders must provide the implementation.

    Args:
        modules_generator: Generator yielding modules to potentially offload.

    Returns:
        List of modules, potentially with offloading hooks installed.
    """

NoopOffloader

Bases: BaseOffloader

No-op offloader that returns modules as-is without any offloading.

Source code in vllm/model_executor/offloader/base.py
class NoopOffloader(BaseOffloader):
    """Offloader that performs no offloading at all.

    Modules pass through untouched; nothing is hooked or moved.
    """

    def wrap_modules(
        self,
        modules_generator: Generator[nn.Module, None, None],
    ) -> list[nn.Module]:
        """Drain the generator and return its modules unchanged, in order."""
        return [*modules_generator]

wrap_modules

wrap_modules(
    modules_generator: Generator[Module, None, None],
) -> list[Module]

Return modules unchanged.

Source code in vllm/model_executor/offloader/base.py
def wrap_modules(
    self,
    modules_generator: Generator[nn.Module, None, None],
) -> list[nn.Module]:
    """Drain the generator and return its modules unchanged, in order."""
    return [*modules_generator]

PrefetchOffloader

Bases: BaseOffloader

Prefetching-based offloader with group-based layer selection.

Groups layers and uses async H2D prefetch to hide transfer latency. Uses static buffers and stream synchronization for torch.compile and CUDA graph compatibility.

Parameters:

Name Type Description Default
group_size int

Group every N layers together.

required
num_in_group int

Offload this many layers per group (last N of each group).

required
prefetch_step int

Number of layers to prefetch ahead.

required
mode str

Offload mode ("cpu" is currently supported).

'cpu'
Source code in vllm/model_executor/offloader/prefetch.py
class PrefetchOffloader(BaseOffloader):
    """Prefetching-based offloader with group-based layer selection.

    Groups layers and uses async H2D prefetch to hide transfer latency.
    Uses static buffers and stream synchronization for torch.compile and
    CUDA graph compatibility.

    Args:
        group_size: Group every N layers together. Must be >= 1.
        num_in_group: Offload this many layers per group (last N of each group).
        prefetch_step: Number of layers to prefetch ahead. Must be >= 1; it is
            also used as the slot capacity of the static buffer pool.
        offload_params: Optional set of parameter-name segments. When
            non-empty, only parameters whose dotted name contains one of the
            entries as whole dot-separated segments are offloaded; when empty
            or ``None``, every parameter of a selected layer is offloaded.
        mode: Offload mode ("cpu" is currently supported).

    Raises:
        ValueError: If ``group_size`` or ``prefetch_step`` is smaller than 1.
    """

    def __init__(
        self,
        group_size: int,
        num_in_group: int,
        prefetch_step: int,
        offload_params: set[str] | None = None,
        mode: str = "cpu",
    ):
        # Both values are used as modulo divisors later (layer selection in
        # wrap_modules, slot assignment in post_init). Reject bad values up
        # front with a clear message instead of a late ZeroDivisionError.
        if group_size < 1:
            raise ValueError(f"group_size must be >= 1, got {group_size}")
        if prefetch_step < 1:
            raise ValueError(f"prefetch_step must be >= 1, got {prefetch_step}")

        self.group_size = group_size
        self.num_in_group = num_in_group
        self.prefetch_step = prefetch_step
        self.offload_params = offload_params or set()
        self.mode = mode

        # Copy stream for async H2D transfers
        self.copy_stream = torch.cuda.Stream()

        # Module offloaders and buffer pool (populated in wrap_modules/post_init)
        self.module_offloaders: list[_ModuleOffloader] = []
        self.buffer_pool: StaticBufferPool | None = None
        self.total_offloaded_bytes = 0

    def wrap_modules(
        self,
        modules_generator: Generator[nn.Module, None, None],
    ) -> list[nn.Module]:
        """Wrap modules with prefetch offloading logic.

        Selects the last ``num_in_group`` layers of every ``group_size``
        layers, creates a ``_ModuleOffloader`` for each selected layer that
        has at least one matching parameter, and hooks its ``forward``.

        Args:
            modules_generator: Generator yielding the layers in order.

        Returns:
            Every yielded module, in order (selected or not).
        """
        assert len(self.module_offloaders) == 0, (
            "wrap_modules should only be called once"
        )

        all_modules = []
        offload_modules = []

        for module_index, module in enumerate(modules_generator):
            all_modules.append(module)

            # Select layers to offload based on group pattern
            # Offload last num_in_group layers of each group_size
            if module_index % self.group_size >= self.group_size - self.num_in_group:
                if self.offload_params:
                    # Surrounding dots restrict matches to whole name segments.
                    whitelist = [
                        name
                        for name, _ in module.named_parameters()
                        if any(f".{p}." in f".{name}." for p in self.offload_params)
                    ]
                else:
                    whitelist = [name for name, _ in module.named_parameters()]

                if not whitelist:
                    continue  # skip layers with no matching params

                offload_modules.append(module)
                self.module_offloaders.append(
                    _ModuleOffloader(
                        mode=self.mode,
                        module=module,
                        copy_stream=self.copy_stream,
                        whitelist_param_names=whitelist,
                        # Position within module_offloaders; the forward hook
                        # below uses the same index.
                        layer_idx=len(self.module_offloaders),
                    )
                )

        # Hook only after every offloader exists; the hooks read
        # len(self.module_offloaders) at call time for the circular index.
        for index, module in enumerate(offload_modules):
            self._hook_module_forward(index, module)

        return all_modules

    def _hook_module_forward(self, index: int, module: nn.Module):
        """Hook module's forward with torch.compile-compatible sync.

        Args:
            index: Position of ``module`` in ``self.module_offloaders``.
            module: Layer whose ``forward`` attribute is replaced.
        """
        original_forward = module.forward

        def forward(*args, **kwargs):
            # Temporarily restore original forward to avoid recursion
            module.forward = original_forward

            # Wait for this layer's prefetch to complete
            # mutates_args on input_tensor creates data dependency for torch.compile
            input_tensor = args[0] if args else kwargs.get("hidden_states")
            torch.ops.vllm.wait_prefetch(input_tensor, index)

            # No parameter swapping needed - parameters already point to
            # GPU static buffers (slots are assigned in post_init)
            output = original_forward(*args, **kwargs)

            # Start prefetch for next layer (circular)
            # mutates_args on output_tensor creates ordering dependency
            next_index = (index + self.prefetch_step) % len(self.module_offloaders)
            # Handle tuple output (e.g., (hidden_states, residual))
            if isinstance(output, tuple):
                torch.ops.vllm.start_prefetch(output[0], next_index)
            else:
                torch.ops.vllm.start_prefetch(output, next_index)

            # No explicit offload needed - static buffers are reused implicitly

            # Restore hooked forward
            module.forward = forward
            return output

        module.forward = forward

    def _wait_for_layer(self, layer_idx: int):
        """Called by custom op - wait for copy to complete.

        Synchronization strategy:
        - During CUDA graph capture: wait on the layer's copy-done event
          (graph-compatible), but only if its prefetch was started in capture.
        - Outside capture (warmup/eager): wait on the layer's copy-done event
          when it is valid for eager use; otherwise fall back to wait_stream
          on the copy stream.

        During capture, the wait is skipped for pre-capture prefetches because:
        1. Pre-capture work is expected to have been synchronized already.
           NOTE(review): the original comment names
           sync_before_graph_capture(), which is not defined in this class;
           sync_prev_onload() looks like the intended call - confirm.
        2. We can't wait on pre-capture events during capture (isolation error)

        Args:
            layer_idx: Index into ``self.module_offloaders``.
        """
        offloader = self.module_offloaders[layer_idx]

        if torch.cuda.is_current_stream_capturing():
            # During capture, skip wait for pre-capture prefetches.
            if not offloader._prefetch_in_capture:
                return
            # Event-based wait for in-capture prefetches (graph-compatible)
            torch.cuda.current_stream().wait_event(offloader._copy_done_event)
            # Mark that this prefetch has been waited on (joined).
            offloader._prefetch_in_capture = False
        else:
            if offloader._event_valid_for_eager:
                # Use per-layer event to only wait for THIS layer's copy,
                # allowing other layers' prefetches to run concurrently.
                torch.cuda.current_stream().wait_event(offloader._copy_done_event)
            else:
                # Event not usable (unrecorded or recorded during capture).
                # Fall back to wait_stream to drain all copy_stream work.
                torch.cuda.current_stream().wait_stream(self.copy_stream)

    def sync_prev_onload(self):
        """Sync previous onload operations.

        Ensures any H2D copies in flight on copy_stream complete before
        the compute stream continues. Call this before CUDA graph
        capture/replay or when synchronization is needed.
        """
        torch.cuda.current_stream().wait_stream(self.copy_stream)

    def _start_prefetch(self, layer_idx: int):
        """Called by custom op - start async copy to static buffer.

        Args:
            layer_idx: Index into ``self.module_offloaders``.
        """
        offloader = self.module_offloaders[layer_idx]
        offloader.start_onload_to_static()

    def join_after_forward(self):
        """Join copy_stream after model forward completes.

        Call this after the model forward pass but before CUDA graph capture
        ends. This ensures copy_stream is rejoined for any prefetches started
        during the forward pass.

        We join ALL layers that have _prefetch_in_capture=True, meaning their
        prefetch was started during capture but not yet waited on (joined).
        This handles both full and piecewise cudagraph modes correctly:
        - Full mode: joins layers 0..prefetch_step-1 (prefetched by last layers)
        - Piecewise mode: joins only layers prefetched by THIS subgraph's layers
        """
        if not self.module_offloaders:
            return
        # Join all layers whose prefetch was started in capture but not waited on
        for offloader in self.module_offloaders:
            if offloader._prefetch_in_capture:
                torch.cuda.current_stream().wait_event(offloader._copy_done_event)
                offloader._prefetch_in_capture = False

    def post_init(self):
        """Allocate static buffer pool and start initial prefetches.

        Returns early, without allocating anything, when no module offloader
        reports a device (i.e. there is nothing to offload).

        Note: Parameters have already been offloaded to CPU during wrap_modules()
        (in _CpuParamOffloader.__init__), so GPU memory is available for the
        static buffer pool.
        """
        # Sync CPU storage with current param.data BEFORE collecting param info.
        # This is needed because process_weights_after_loading may have:
        # 1. Transformed weights (quantization, transpose, etc.)
        # 2. Created new CPU tensors via device_loading_context
        # Our _cpu_storage would be stale otherwise.
        for offloader in self.module_offloaders:
            offloader.sync_cpu_storage()

        # Collect parameter info (now using synced CPU storage)
        param_infos: list[ParamInfo] = []
        device: torch.device | None = None

        for offloader in self.module_offloaders:
            param_infos.extend(offloader.get_param_infos())
            if device is None:
                device = offloader.device

        if device is None:
            # No modules to offload
            return

        # Allocate static buffer pool
        self.buffer_pool = StaticBufferPool(
            param_infos=param_infos,
            slot_capacity=self.prefetch_step,
            device=device,
        )

        # Assign buffer slots and point parameters to GPU buffers.
        # Slots are handed out round-robin over prefetch_step slots.
        for idx, offloader in enumerate(self.module_offloaders):
            slot_idx = idx % self.prefetch_step
            offloader.assign_buffer_slot(self.buffer_pool, slot_idx)

        # Collect offloaded bytes
        for offloader in self.module_offloaders:
            offloader.post_init()
            self.total_offloaded_bytes += offloader.offloaded_bytes

        logger.info_once(
            f"[PrefetchOffloader] Initialized {len(self.module_offloaders)} modules. "
            f"Total GPU memory saved: {self.total_offloaded_bytes / 1e9:.4f} GB, "
            f"Static buffer pool: {self.buffer_pool.total_bytes / 1e9:.4f} GB "
            f"(group_size={self.group_size}, num_in_group={self.num_in_group}, "
            f"prefetch_step={self.prefetch_step}, mode={self.mode})"
        )

        # Start initial prefetches
        for i in range(min(self.prefetch_step, len(self.module_offloaders))):
            self.module_offloaders[i].start_onload_to_static()

_hook_module_forward

_hook_module_forward(index: int, module: Module)

Hook module's forward with torch.compile-compatible sync.

Source code in vllm/model_executor/offloader/prefetch.py
def _hook_module_forward(self, index: int, module: nn.Module):
    """Replace ``module.forward`` with a prefetch-aware wrapper.

    The wrapper waits for this layer's prefetch, runs the original forward,
    then starts the prefetch of the layer ``prefetch_step`` positions ahead
    (wrapping around). Both steps go through ``torch.ops.vllm`` custom ops
    so they stay torch.compile-compatible.

    Args:
        index: Position of ``module`` in ``self.module_offloaders``.
        module: Layer whose ``forward`` attribute is replaced.
    """
    unhooked = module.forward

    def forward(*args, **kwargs):
        # Put the original forward back for the duration of the call so a
        # nested module.forward lookup does not recurse into this wrapper.
        module.forward = unhooked

        # First positional argument, else the "hidden_states" keyword; the
        # wait op mutates it, which gives torch.compile a data dependency.
        if args:
            first_input = args[0]
        else:
            first_input = kwargs.get("hidden_states")
        torch.ops.vllm.wait_prefetch(first_input, index)

        # Parameters already point at the GPU static buffers, so nothing is
        # swapped in before running the layer.
        result = unhooked(*args, **kwargs)

        # Circular look-ahead; the start op mutates the output tensor to pin
        # its ordering. Tuple outputs (e.g. (hidden_states, residual)) use
        # their first element.
        upcoming = (index + self.prefetch_step) % len(self.module_offloaders)
        anchor = result[0] if isinstance(result, tuple) else result
        torch.ops.vllm.start_prefetch(anchor, upcoming)

        # Static buffers are reused implicitly, so there is no offload step.
        # Re-install the wrapper for the next call.
        module.forward = forward
        return result

    module.forward = forward

_start_prefetch

_start_prefetch(layer_idx: int)

Called by custom op - start async copy to static buffer.

Source code in vllm/model_executor/offloader/prefetch.py
def _start_prefetch(self, layer_idx: int):
    """Kick off the async copy of one layer into its static buffer.

    Invoked by the custom op.

    Args:
        layer_idx: Index into ``self.module_offloaders``.
    """
    self.module_offloaders[layer_idx].start_onload_to_static()

_wait_for_layer

_wait_for_layer(layer_idx: int)

Called by custom op - wait for copy to complete.

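Synchronization strategy:

- During CUDA graph capture: use an event-based wait (graph-compatible).
- Outside capture (warmup/eager): wait on the layer's own copy-done event when it is valid for eager use; otherwise fall back to wait_stream on the copy stream (more robust).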

During capture, we skip the wait for pre-capture prefetches because:

1. sync_before_graph_capture() ensures pre-capture work is complete.
2. We can't wait on pre-capture events during capture (isolation error).

Source code in vllm/model_executor/offloader/prefetch.py
def _wait_for_layer(self, layer_idx: int):
    """Make the current stream wait for one layer's copy (custom-op entry).

    While the current stream is capturing a CUDA graph:
    - if the layer's prefetch was not started inside the capture, do nothing
      (pre-capture events cannot be waited on during capture);
    - otherwise wait on the layer's copy-done event and clear its
      ``_prefetch_in_capture`` flag to mark it joined.

    Outside capture (warmup/eager):
    - wait on the layer's copy-done event when it is valid for eager use,
      so other layers' prefetches may keep running;
    - otherwise wait on the whole copy stream.

    Args:
        layer_idx: Index into ``self.module_offloaders``.
    """
    layer = self.module_offloaders[layer_idx]

    if torch.cuda.is_current_stream_capturing():
        if layer._prefetch_in_capture:
            # Event wait is graph-compatible for prefetches issued in capture.
            torch.cuda.current_stream().wait_event(layer._copy_done_event)
            # This prefetch is now joined.
            layer._prefetch_in_capture = False
        return

    if not layer._event_valid_for_eager:
        # Event is unrecorded or was recorded during capture: drain all
        # outstanding copy_stream work instead.
        torch.cuda.current_stream().wait_stream(self.copy_stream)
        return

    # Per-layer event: only THIS layer's copy is awaited.
    torch.cuda.current_stream().wait_event(layer._copy_done_event)

join_after_forward

join_after_forward()

Join copy_stream after model forward completes.

Call this after the model forward pass but before CUDA graph capture ends. This ensures copy_stream is rejoined for any prefetches started during the forward pass.

We join ALL layers that have _prefetch_in_capture=True, meaning their prefetch was started during capture but not yet waited on (joined). This handles both full and piecewise cudagraph modes correctly:

- Full mode: joins layers 0..prefetch_step-1 (prefetched by last layers)
- Piecewise mode: joins only layers prefetched by THIS subgraph's layers

Source code in vllm/model_executor/offloader/prefetch.py
def join_after_forward(self):
    """Rejoin copy_stream once the model forward has finished.

    Intended to run after the forward pass but before CUDA graph capture
    ends. Every layer still flagged ``_prefetch_in_capture`` (prefetch
    started in capture, not yet waited on) is joined via its copy-done
    event and then unflagged. This covers both cudagraph modes:
    - Full mode: joins layers 0..prefetch_step-1 (prefetched by last layers)
    - Piecewise mode: joins only layers prefetched by THIS subgraph's layers
    """
    # An empty offloader list simply yields no iterations.
    for pending in self.module_offloaders:
        if not pending._prefetch_in_capture:
            continue
        torch.cuda.current_stream().wait_event(pending._copy_done_event)
        pending._prefetch_in_capture = False

post_init

post_init()

Allocate static buffer pool and start initial prefetches.

Note: Parameters have already been offloaded to CPU during wrap_modules() (in `_CpuParamOffloader.__init__`), so GPU memory is available for the static buffer pool.

Source code in vllm/model_executor/offloader/prefetch.py
def post_init(self):
    """Build the static buffer pool and launch the first prefetches.

    Returns early, allocating nothing, when no module offloader reports a
    device (nothing to offload).

    Note: parameters were already moved to CPU during wrap_modules()
    (in _CpuParamOffloader.__init__), so GPU memory is free for the pool.
    """
    offloaders = self.module_offloaders

    # Refresh CPU storage from the current param.data before reading param
    # info: process_weights_after_loading may have transformed the weights
    # (quantization, transpose, etc.) or created new CPU tensors via
    # device_loading_context, which would leave _cpu_storage stale.
    for entry in offloaders:
        entry.sync_cpu_storage()

    # Parameter info comes from the freshly synced CPU storage.
    param_infos: list[ParamInfo] = [
        info for entry in offloaders for info in entry.get_param_infos()
    ]
    # First device any offloader reports, if there is one.
    device: torch.device | None = next(
        (entry.device for entry in offloaders if entry.device is not None),
        None,
    )
    if device is None:
        return

    self.buffer_pool = StaticBufferPool(
        param_infos=param_infos,
        slot_capacity=self.prefetch_step,
        device=device,
    )

    # Hand out slots round-robin and point parameters at the GPU buffers.
    for position, entry in enumerate(offloaders):
        entry.assign_buffer_slot(self.buffer_pool, position % self.prefetch_step)

    # Finish each offloader and tally the bytes it moved off the GPU.
    for entry in offloaders:
        entry.post_init()
        self.total_offloaded_bytes += entry.offloaded_bytes

    logger.info_once(
        f"[PrefetchOffloader] Initialized {len(self.module_offloaders)} modules. "
        f"Total GPU memory saved: {self.total_offloaded_bytes / 1e9:.4f} GB, "
        f"Static buffer pool: {self.buffer_pool.total_bytes / 1e9:.4f} GB "
        f"(group_size={self.group_size}, num_in_group={self.num_in_group}, "
        f"prefetch_step={self.prefetch_step}, mode={self.mode})"
    )

    # Prime the first prefetch_step layers (or fewer if there aren't enough).
    initial_count = min(self.prefetch_step, len(offloaders))
    for position in range(initial_count):
        offloaders[position].start_onload_to_static()

sync_prev_onload

sync_prev_onload()

Sync previous onload operations.

Ensures any H2D copies in flight on copy_stream complete before the compute stream continues. Call this before CUDA graph capture/replay or when synchronization is needed.

Source code in vllm/model_executor/offloader/prefetch.py
def sync_prev_onload(self):
    """Make the current stream wait for all work queued on copy_stream.

    This ensures in-flight H2D copies finish before the compute stream
    continues. Use it before CUDA graph capture/replay, or whenever such
    synchronization is needed.
    """
    compute_stream = torch.cuda.current_stream()
    compute_stream.wait_stream(self.copy_stream)

wrap_modules

wrap_modules(
    modules_generator: Generator[Module, None, None],
) -> list[Module]

Wrap modules with prefetch offloading logic.

Source code in vllm/model_executor/offloader/prefetch.py
def wrap_modules(
    self,
    modules_generator: Generator[nn.Module, None, None],
) -> list[nn.Module]:
    """Select layers for prefetch offloading and hook their forward.

    The last ``num_in_group`` layers of every ``group_size`` layers are
    candidates. A candidate is offloaded only if it has at least one
    parameter to offload (all of them, or those matching
    ``offload_params`` when that set is non-empty).

    Args:
        modules_generator: Generator yielding the layers in order.

    Returns:
        Every yielded module, in order (selected or not).
    """
    assert not self.module_offloaders, (
        "wrap_modules should only be called once"
    )

    seen: list[nn.Module] = []
    selected: list[nn.Module] = []
    # Positions at or beyond this offset inside a group get offloaded.
    first_offloaded = self.group_size - self.num_in_group

    for position, layer in enumerate(modules_generator):
        seen.append(layer)

        if position % self.group_size < first_offloaded:
            continue

        param_names = [name for name, _ in layer.named_parameters()]
        if self.offload_params:
            # Surrounding dots restrict matches to whole name segments.
            param_names = [
                name
                for name in param_names
                if any(f".{p}." in f".{name}." for p in self.offload_params)
            ]

        if not param_names:
            # Nothing to offload in this layer.
            continue

        selected.append(layer)
        wrapper = _ModuleOffloader(
            mode=self.mode,
            module=layer,
            copy_stream=self.copy_stream,
            whitelist_param_names=param_names,
            layer_idx=len(self.module_offloaders),
        )
        self.module_offloaders.append(wrapper)

    # Hooks are installed only after every offloader has been created.
    for hook_index, layer in enumerate(selected):
        self._hook_module_forward(hook_index, layer)

    return seen

UVAOffloader

Bases: BaseOffloader

Offloader using Unified Virtual Addressing (UVA) for zero-copy access.

This offloader moves parameters to pinned CPU memory and creates CUDA views using UVA. The GPU can then directly access the CPU memory without explicit transfers, at the cost of PCIe bandwidth (slower than GPU memory).

When UVA is disabled via env var, falls back to a functional_call-based approach that moves parameters on-demand.

Parameters:

Name Type Description Default
cpu_offload_max_bytes int

Maximum bytes to offload to CPU.

required
cpu_offload_params set[str] | None

Set of parameter name segments to selectively offload. If empty, all parameters are eligible up to the byte limit.

None
Source code in vllm/model_executor/offloader/uva.py
class UVAOffloader(BaseOffloader):
    """Offloader built on Unified Virtual Addressing (UVA) zero-copy access.

    Parameters are moved to (pinned, when available) CPU memory and exposed
    to the GPU as CUDA views through UVA. The GPU then reads the CPU memory
    directly without explicit transfers, paying PCIe bandwidth instead
    (slower than GPU memory).

    When UVA is disabled via env var, a functional_call-based fallback moves
    parameters to the device on demand for each forward call.

    Args:
        cpu_offload_max_bytes: Maximum bytes to offload to CPU.
        cpu_offload_params: Set of parameter name segments to selectively
            offload. If empty, all parameters are eligible up to the byte limit.
    """

    def __init__(
        self,
        cpu_offload_max_bytes: int,
        cpu_offload_params: set[str] | None = None,
    ):
        self.cpu_offload_max_bytes = cpu_offload_max_bytes
        # Running total of bytes offloaded so far.
        self.cpu_offload_bytes = 0
        self.cpu_offload_params = cpu_offload_params or set()

        # Each feature needs platform support and must not be disabled by
        # its environment variable.
        pinning_supported = is_pin_memory_available()
        self.pin_memory = (
            pinning_supported and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
        )
        uva_supported = is_uva_available()
        self.uva_offloading = (
            uva_supported and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_UVA
        )

    def wrap_modules(
        self,
        modules_generator: Generator[nn.Module, None, None],
    ) -> list[nn.Module]:
        """Offload each yielded module as the budget allows and return them all."""
        wrapped = list(map(self._maybe_offload_to_cpu, modules_generator))
        if self.cpu_offload_bytes > 0:
            logger.info(
                "Total CPU offloaded parameters: %s",
                format_gib(self.cpu_offload_bytes),
            )
        return wrapped

    def _maybe_offload_to_cpu(self, module: nn.Module) -> nn.Module:
        """Move a module's parameters to CPU while the byte budget lasts.

        Args:
            module: Module to process; its parameters are modified in place.

        Returns:
            The same module object. It is left untouched when it has no
            parameters, already lives on CPU, or the budget is used up.
        """
        first_param = next(module.parameters(), None)
        if first_param is None:
            return module

        device = first_param.device
        if device == torch.device("cpu"):
            return module

        if self.cpu_offload_bytes >= self.cpu_offload_max_bytes:
            return module

        # Pinned memory is used when possible; it helps cudagraph capture speed.
        any_offloaded = False
        for name, p in module.named_parameters():
            # Offloading is per parameter, so a module can end up with only
            # some of its parameters on CPU once the budget runs out.
            if self.cpu_offload_bytes >= self.cpu_offload_max_bytes:
                break

            # With a selection set, keep only parameters matching an entry.
            # The surrounding dots force whole-segment matches, e.g.
            # "experts.w2_weight" matches "mlp.experts.w2_weight"
            # but not "mlp.experts.w2_weight_scale".
            if self.cpu_offload_params and not any(
                f".{param}." in f".{name}." for param in self.cpu_offload_params
            ):
                continue

            host_copy = p.data.to(device="cpu")
            if self.pin_memory:
                host_copy = host_copy.pin_memory()

            if self.uva_offloading:
                p.data = get_accelerator_view_from_cpu_tensor(host_copy)
                p._vllm_is_uva_offloaded = True
            else:
                p.data = host_copy

            self.cpu_offload_bytes += p.data.numel() * p.data.element_size()
            any_offloaded = True

        # Only the non-UVA path needs a forward hook: its parameters now sit
        # on CPU and must be brought to the device for every call.
        if not any_offloaded or self.uva_offloading:
            return module

        original_forward = module.forward

        def forward(*args, **kwargs):
            module.forward = original_forward

            # `.to(device)` is applied blindly to every entry; it is a no-op
            # for tensors that are already on the device.
            device_state = {}
            for key, tensor in module.state_dict().items():
                device_state[key] = tensor.to(device, non_blocking=True)

            # `tie_weights=False` because weights tied in the original model
            # become untied once each is moved with its own .to(device) call.
            output = functional_call(
                module,
                device_state,
                args=args,
                kwargs=kwargs,
                tie_weights=False,
            )
            module.forward = forward
            return output

        module.forward = forward
        return module

_maybe_offload_to_cpu

_maybe_offload_to_cpu(module: Module) -> Module

Offload module parameters to CPU using UVA if budget allows.

Source code in vllm/model_executor/offloader/uva.py
def _maybe_offload_to_cpu(self, module: nn.Module) -> nn.Module:
    """Offload a module's parameters to CPU while the byte budget allows.

    Parameters are offloaded one at a time until ``self.cpu_offload_bytes``
    reaches ``self.cpu_offload_max_bytes``, so a module may end up with only
    some of its parameters offloaded. When ``self.cpu_offload_params`` is
    non-empty, only parameters whose dotted name contains one of its entries
    as whole segments are considered.

    With ``self.uva_offloading`` enabled, each offloaded parameter's data is
    replaced by ``get_accelerator_view_from_cpu_tensor(cpu_data)`` and the
    parameter is tagged with ``_vllm_is_uva_offloaded``. Otherwise the data
    stays on CPU and ``module.forward`` is wrapped so that every call copies
    the state dict to the original device and runs via ``functional_call``.

    Args:
        module: Module whose parameters may be offloaded.

    Returns:
        The same module object, modified in place where applicable.
    """
    if (params := next(module.parameters(), None)) is None:
        # no parameters at all, nothing to offload
        return module

    # device of the first parameter; used as the target for on-the-fly copies
    device = params.device

    if device == torch.device("cpu"):
        return module

    if self.cpu_offload_bytes >= self.cpu_offload_max_bytes:
        return module

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for name, p in module.named_parameters():
        if self.cpu_offload_bytes >= self.cpu_offload_max_bytes:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        if self.cpu_offload_params:
            # Check if parameter belongs to the offloading set
            # Add dots here to ensure we match full segments only
            # e.g., "experts.w2_weight" matches "mlp.experts.w2_weight"
            # but not "mlp.experts.w2_weight_scale"
            should_offload = any(
                f".{param}." in f".{name}." for param in self.cpu_offload_params
            )
            if not should_offload:
                continue

        cpu_data = p.data.to(device="cpu")
        if self.pin_memory:
            cpu_data = cpu_data.pin_memory()

        if not self.uva_offloading:
            p.data = cpu_data
        else:
            p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
            p._vllm_is_uva_offloaded = True

        self.cpu_offload_bytes += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters and not self.uva_offloading:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # expose the unwrapped forward while functional_call runs, so
            # the call it makes does not re-enter this wrapper
            module.forward = original_forward
            try:
                device_state = {
                    # here we blindly call `to(device)`
                    # if the parameter is already on the device,
                    # it will be a no-op
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }

                # set `tie_weights=False` as tied weights in original model
                # become untied when calling .to(device) individually
                return functional_call(
                    module,
                    device_state,
                    args=args,
                    kwargs=kwargs,
                    tie_weights=False,
                )
            finally:
                # always re-install the wrapper, even if the forward pass
                # raised; otherwise later calls would skip the device copy
                module.forward = forward

        module.forward = forward

    return module

wrap_modules

wrap_modules(
    modules_generator: Generator[Module, None, None],
) -> list[Module]

Wrap modules with UVA offloading.

Source code in vllm/model_executor/offloader/uva.py
def wrap_modules(
    self,
    modules_generator: Generator[nn.Module, None, None],
) -> list[nn.Module]:
    """Run every yielded module through ``_maybe_offload_to_cpu``.

    Args:
        modules_generator: Generator yielding the candidate modules.

    Returns:
        The processed modules, in the order they were yielded.
    """
    wrapped = list(map(self._maybe_offload_to_cpu, modules_generator))

    # report the running total only when something was actually offloaded
    total_bytes = self.cpu_offload_bytes
    if total_bytes > 0:
        logger.info("Total CPU offloaded parameters: %s", format_gib(total_bytes))
    return wrapped

create_offloader

create_offloader(
    offload_config: OffloadConfig,
) -> BaseOffloader

Create an offloader based on the offload configuration.

Uses the explicit offload_backend selector. When set to "auto", selects prefetch if offload_group_size > 0, UVA if cpu_offload_gb > 0, otherwise noop.

Source code in vllm/model_executor/offloader/base.py
def create_offloader(offload_config: "OffloadConfig") -> BaseOffloader:
    """Build the offloader selected by ``offload_config``.

    The choice is driven by ``offload_config.offload_backend``.  A value of
    ``"auto"`` resolves to prefetch when ``offload_group_size > 0``, to UVA
    when ``cpu_offload_gb > 0``, and to the no-op offloader otherwise.  Any
    backend other than ``"prefetch"`` or ``"uva"`` yields a ``NoopOffloader``.
    """
    from vllm.model_executor.offloader.prefetch import PrefetchOffloader
    from vllm.model_executor.offloader.uva import UVAOffloader

    selected = offload_config.offload_backend
    uva_cfg = offload_config.uva
    prefetch_cfg = offload_config.prefetch

    # resolve "auto" to a concrete backend name first
    if selected == "auto":
        if prefetch_cfg.offload_group_size > 0:
            selected = "prefetch"
        elif uva_cfg.cpu_offload_gb > 0:
            selected = "uva"
        else:
            return NoopOffloader()

    if selected == "prefetch":
        return PrefetchOffloader(
            group_size=prefetch_cfg.offload_group_size,
            num_in_group=prefetch_cfg.offload_num_in_group,
            prefetch_step=prefetch_cfg.offload_prefetch_step,
            offload_params=prefetch_cfg.offload_params,
            mode="cpu",
        )

    if selected == "uva":
        # the budget is configured in GiB; the offloader takes whole bytes
        max_bytes = int(uva_cfg.cpu_offload_gb * 1024**3)
        return UVAOffloader(
            cpu_offload_max_bytes=max_bytes,
            cpu_offload_params=uva_cfg.cpu_offload_params,
        )

    return NoopOffloader()

get_offloader

get_offloader() -> BaseOffloader

Get the global offloader instance.

Source code in vllm/model_executor/offloader/base.py
def get_offloader() -> BaseOffloader:
    """Return the offloader currently stored in the module-level ``_instance``."""
    current = _instance
    return current

set_offloader

set_offloader(instance: BaseOffloader) -> None

Set the global offloader instance.

Source code in vllm/model_executor/offloader/base.py
def set_offloader(instance: BaseOffloader) -> None:
    """Replace the module-level offloader and log the new offloader's class name.

    Args:
        instance: Offloader to store in the module-level ``_instance``.
    """
    global _instance
    backend_name = type(instance).__name__
    _instance = instance
    logger.info("Offloader set to %s", backend_name)