Skip to content

vllm.model_executor.offloader.uva

UVA-based CPU offloading using Unified Virtual Addressing.

UVAOffloader

Bases: BaseOffloader

Offloader using Unified Virtual Addressing (UVA) for zero-copy access.

This offloader moves parameters to pinned CPU memory and creates CUDA views using UVA. The GPU can then directly access the CPU memory without explicit transfers, at the cost of PCIe bandwidth (slower than GPU memory).

When UVA is disabled via env var, falls back to a functional_call-based approach that moves parameters on-demand.

Parameters:

Name Type Description Default
cpu_offload_max_bytes int

Maximum bytes to offload to CPU.

required
cpu_offload_params set[str] | None

Set of parameter name segments to selectively offload. If empty, all parameters are eligible up to the byte limit.

None
Source code in vllm/model_executor/offloader/uva.py
class UVAOffloader(BaseOffloader):
    """Offloader using Unified Virtual Addressing (UVA) for zero-copy access.

    This offloader moves parameters to pinned CPU memory and creates CUDA views
    using UVA. The GPU can then directly access the CPU memory without explicit
    transfers, at the cost of PCIe bandwidth (slower than GPU memory).

    When UVA is disabled via env var, falls back to a functional_call-based
    approach that moves parameters on-demand.

    Args:
        cpu_offload_max_bytes: Maximum bytes to offload to CPU.
        cpu_offload_params: Set of parameter name segments to selectively
            offload. If empty, all parameters are eligible up to the byte limit.
    """

    def __init__(
        self,
        cpu_offload_max_bytes: int,
        cpu_offload_params: set[str] | None = None,
    ):
        self.cpu_offload_max_bytes = cpu_offload_max_bytes
        self.cpu_offload_bytes = 0
        self.cpu_offload_params = cpu_offload_params or set()

        self.pin_memory = (
            is_pin_memory_available()
            and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
        )
        self.uva_offloading = (
            is_uva_available() and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_UVA
        )

    def wrap_modules(
        self,
        modules_generator: Generator[nn.Module, None, None],
    ) -> list[nn.Module]:
        """Wrap modules with UVA offloading."""
        modules = [self._maybe_offload_to_cpu(module) for module in modules_generator]
        if self.cpu_offload_bytes > 0:
            logger.info(
                "Total CPU offloaded parameters: %s",
                format_gib(self.cpu_offload_bytes),
            )
        return modules

    def _maybe_offload_to_cpu(self, module: nn.Module) -> nn.Module:
        """Offload module parameters to CPU using UVA if budget allows."""
        if (params := next(module.parameters(), None)) is None:
            return module

        device = params.device

        if device == torch.device("cpu"):
            return module

        if self.cpu_offload_bytes >= self.cpu_offload_max_bytes:
            return module

        # offload parameters to CPU
        # use pin_memory if possible, which helps cudagraph capture speed
        offloaded_parameters = False
        for name, p in module.named_parameters():
            if self.cpu_offload_bytes >= self.cpu_offload_max_bytes:
                # we use per-parameter offloading
                # one module might have some parameters offloaded and some not
                break

            if self.cpu_offload_params:
                # Check if parameter belongs to the offloading set
                # Add dots here to ensure we match full segments only
                # e.g., "experts.w2_weight" matches "mlp.experts.w2_weight"
                # but not "mlp.experts.w2_weight_scale"
                should_offload = any(
                    f".{param}." in f".{name}." for param in self.cpu_offload_params
                )
                if not should_offload:
                    continue

            cpu_data = p.data.to(device="cpu")
            if self.pin_memory:
                cpu_data = cpu_data.pin_memory()

            if not self.uva_offloading:
                p.data = cpu_data
            else:
                p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
                p._vllm_is_uva_offloaded = True

            self.cpu_offload_bytes += p.data.numel() * p.data.element_size()
            offloaded_parameters = True

        if offloaded_parameters and not self.uva_offloading:
            original_forward = module.forward

            def forward(*args, **kwargs):
                module.forward = original_forward
                device_state = {
                    # here we blindly call `to(device)`
                    # if the parameter is already on the device,
                    # it will be a no-op
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }

                # set `tie_weights=False` as tied weights in original model
                # become untied when calling .to(device) individually
                output = functional_call(
                    module,
                    device_state,
                    args=args,
                    kwargs=kwargs,
                    tie_weights=False,
                )
                module.forward = forward
                return output

            module.forward = forward

        return module

_maybe_offload_to_cpu

_maybe_offload_to_cpu(module: Module) -> Module

Offload module parameters to CPU using UVA if budget allows.

Source code in vllm/model_executor/offloader/uva.py
def _maybe_offload_to_cpu(self, module: nn.Module) -> nn.Module:
    """Offload module parameters to CPU using UVA if budget allows."""
    if (params := next(module.parameters(), None)) is None:
        return module

    device = params.device

    if device == torch.device("cpu"):
        return module

    if self.cpu_offload_bytes >= self.cpu_offload_max_bytes:
        return module

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for name, p in module.named_parameters():
        if self.cpu_offload_bytes >= self.cpu_offload_max_bytes:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        if self.cpu_offload_params:
            # Check if parameter belongs to the offloading set
            # Add dots here to ensure we match full segments only
            # e.g., "experts.w2_weight" matches "mlp.experts.w2_weight"
            # but not "mlp.experts.w2_weight_scale"
            should_offload = any(
                f".{param}." in f".{name}." for param in self.cpu_offload_params
            )
            if not should_offload:
                continue

        cpu_data = p.data.to(device="cpu")
        if self.pin_memory:
            cpu_data = cpu_data.pin_memory()

        if not self.uva_offloading:
            p.data = cpu_data
        else:
            p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
            p._vllm_is_uva_offloaded = True

        self.cpu_offload_bytes += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters and not self.uva_offloading:
        original_forward = module.forward

        def forward(*args, **kwargs):
            module.forward = original_forward
            device_state = {
                # here we blindly call `to(device)`
                # if the parameter is already on the device,
                # it will be a no-op
                k: v.to(device, non_blocking=True)
                for k, v in module.state_dict().items()
            }

            # set `tie_weights=False` as tied weights in original model
            # become untied when calling .to(device) individually
            output = functional_call(
                module,
                device_state,
                args=args,
                kwargs=kwargs,
                tie_weights=False,
            )
            module.forward = forward
            return output

        module.forward = forward

    return module

wrap_modules

wrap_modules(
    modules_generator: Generator[Module, None, None],
) -> list[Module]

Wrap modules with UVA offloading.

Source code in vllm/model_executor/offloader/uva.py
def wrap_modules(
    self,
    modules_generator: Generator[nn.Module, None, None],
) -> list[nn.Module]:
    """Wrap modules with UVA offloading."""
    modules = [self._maybe_offload_to_cpu(module) for module in modules_generator]
    if self.cpu_offload_bytes > 0:
        logger.info(
            "Total CPU offloaded parameters: %s",
            format_gib(self.cpu_offload_bytes),
        )
    return modules