vllm.device_allocator.cumem

HandleType module-attribute

HandleType = tuple[int, int, int, int]
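
The handle tuple is produced by the cumem_allocator C extension and passed back to it unchanged (python_create_and_map(*handle), python_unmap_and_release(*handle)). A brief sketch of the two fields this module reads directly; the handle values below are hypothetical, made up purely for illustration, and the remaining fields are opaque to the Python side:

from vllm.device_allocator.cumem import HandleType

# Hypothetical handle for illustration; real handles come from the C
# extension via python_malloc_callback.
handle: HandleType = (0, 4096, 0x7F0000000000, 0)

size_in_bytes = handle[1]  # read by sleep() to size the CPU backup tensor
device_ptr = handle[2]     # used as the key into pointer_to_data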

cumem_available module-attribute

cumem_available = True

lib_name module-attribute

lib_name = find_loaded_library('cumem_allocator')

libcudart module-attribute

libcudart = CudaRTLibrary()

AllocationData dataclass

Source code in vllm/device_allocator/cumem.py
@dataclasses.dataclass
class AllocationData:
    handle: HandleType
    tag: str
    cpu_backup_tensor: Optional[torch.Tensor] = None

cpu_backup_tensor class-attribute instance-attribute

cpu_backup_tensor: Optional[Tensor] = None

handle instance-attribute

handle: HandleType

tag instance-attribute

tag: str

__init__

__init__(
    handle: HandleType,
    tag: str,
    cpu_backup_tensor: Optional[Tensor] = None,
) -> None

CuMemAllocator

A singleton class that manages a memory pool for CUDA tensors. The memory in this pool can be offloaded or discarded when the allocator sleeps.

Inside the use_memory_pool(tag) context, all tensors created are allocated from the memory pool and carry the tag passed to the context.

When we call sleep, all tensors with the specified tags are offloaded to CPU memory, and the rest are discarded. When we call wake_up, all tensors that were previously offloaded are loaded back to GPU memory, and the rest are given empty (uninitialized) memory.

Why does it need to be a singleton? When allocated tensors are garbage collected, PyTorch calls the free callback, which invokes the python_free_callback method. The C extension uses a global variable to store that callback, bound to a single instance of this class. If multiple instances were created, the global variable would be overwritten and the free callback would not work as expected.
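
A minimal end-to-end sketch of the intended usage, assuming a CUDA device and the compiled cumem_allocator extension are available; the tag names are illustrative, not required by the allocator. Note that while the allocator is asleep, the underlying GPU memory is unmapped, so the tensors must not be touched until wake_up is called.

import torch
from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()

# Allocate under a custom tag inside the pool.
with allocator.use_memory_pool(tag="weights"):
    weights = torch.empty(1024, 1024, device="cuda")

# Allocate under another tag.
with allocator.use_memory_pool(tag="kv_cache"):
    kv_cache = torch.empty(512, 1024, device="cuda")

# Offload "weights" to pinned CPU memory, discard everything else.
allocator.sleep(offload_tags="weights")

# Copy "weights" back to GPU; "kv_cache" gets fresh, uninitialized memory.
allocator.wake_up()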

Source code in vllm/device_allocator/cumem.py
class CuMemAllocator:
    """
    A singleton class that manages a memory pool for CUDA tensors.
    The memory in this pool can be offloaded or discarded when the
    allocator sleeps.

    Inside the `use_memory_pool(tag)` context, all tensors created will
    be allocated in the memory pool, and has the same tag as the
    tag passed to the context.

    When we call `sleep`, all tensors with the specified tag will be
    offloaded to CPU memory, and the rest of the tensors will be discarded.
    When we call `wake_up`, all tensors that are previously offloaded
    will be loaded back to GPU memory, and the rest of the tensors will
    have empty memory.

    Why it needs to be a singleton?
    When allocated tensors are garbage collected, PyTorch will call
    the free callback, which will call the `python_free_callback` method.
    The C-extension uses a global variable to store the function of an
    instance of this class. If we create multiple instances of this class,
    the global variable will be overwritten and the free callback will
    not work as expected.
    """
    instance: "CuMemAllocator" = None
    default_tag: str = "default"

    @staticmethod
    def get_instance() -> "CuMemAllocator":
        """
        CuMemAllocator is a singleton class.
        We cannot call the constructor directly.
        Call this method to get the instance.
        """
        assert cumem_available, "cumem allocator is not available"
        if CuMemAllocator.instance is None:
            CuMemAllocator.instance = CuMemAllocator()
        return CuMemAllocator.instance

    def __init__(self):
        conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
        assert "expandable_segments:True" not in conf, \
            ("Expandable segments are not compatible with memory pool. "
            "Please track https://github.com/pytorch/pytorch/issues/147851 "
            "for the latest updates.")

        self.pointer_to_data: dict[int, AllocationData] = {}
        self.current_tag: str = CuMemAllocator.default_tag
        self.allocator_and_pools: dict[str, Any] = {}

    def python_malloc_callback(self, allocation_handle: HandleType) -> None:
        """
        Internal method to store the allocation data
        when memory is allocated in the memory pool."""
        py_d_mem = allocation_handle[2]
        self.pointer_to_data[py_d_mem] = AllocationData(
            allocation_handle, self.current_tag)
        return

    def python_free_callback(self, ptr: int) -> HandleType:
        """
        Internal method to look up the allocation data
        when memory is freed in the memory pool."""
        data = self.pointer_to_data.pop(ptr)
        if data.cpu_backup_tensor is not None:
            data.cpu_backup_tensor = None
        return data.handle

    def sleep(
            self,
            offload_tags: Optional[Union[tuple[str, ...],
                                         str]] = None) -> None:
        """
        Put the allocator in sleep mode.
        All data in the memory allocation with the specified tag will be
        offloaded to CPU memory, and others will be discarded.

        :param offload_tags: The tags of the memory allocation that will be
            offloaded. The rest of the memory allocation will be discarded.
        """
        if offload_tags is None:
            # by default, allocated tensors are offloaded
            # when the allocator sleeps
            offload_tags = (CuMemAllocator.default_tag, )
        elif isinstance(offload_tags, str):
            offload_tags = (offload_tags, )

        assert isinstance(offload_tags, tuple)

        for ptr, data in self.pointer_to_data.items():
            handle = data.handle
            if data.tag in offload_tags:
                size_in_bytes = handle[1]
                cpu_backup_tensor = torch.empty(
                    size_in_bytes,
                    dtype=torch.uint8,
                    device='cpu',
                    pin_memory=is_pin_memory_available())
                cpu_ptr = cpu_backup_tensor.data_ptr()
                libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes)
                data.cpu_backup_tensor = cpu_backup_tensor
            unmap_and_release(handle)

        gc.collect()
        torch.cuda.empty_cache()

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """
        Wake up the allocator from sleep mode.
        All data that is previously offloaded will be loaded back to GPU 
        memory, and the rest of the data will have empty memory.

        :param tags: The tags of the memory allocation that will be loaded
            back to GPU memory. If None, all memory allocation will be loaded
            back to GPU memory.
        """
        for ptr, data in self.pointer_to_data.items():
            if tags is None or data.tag in tags:
                handle = data.handle
                create_and_map(handle)
                if data.cpu_backup_tensor is not None:
                    cpu_backup_tensor = data.cpu_backup_tensor
                    if cpu_backup_tensor is not None:
                        size_in_bytes = cpu_backup_tensor.numel(
                        ) * cpu_backup_tensor.element_size()
                        cpu_ptr = cpu_backup_tensor.data_ptr()
                        libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes)
                        data.cpu_backup_tensor = None

    @contextmanager
    def use_memory_pool(self, tag: Optional[str] = None):
        """
        A context manager to use the memory pool.
        All memory allocation created inside the context will be allocated
        in the memory pool, and has the specified tag.

        :param tag: The tag of the memory allocation. If None, the default tag
            will be used.
        """
        if tag is None:
            tag = CuMemAllocator.default_tag

        assert isinstance(tag, str)

        old_tag = self.current_tag
        self.current_tag = tag
        with use_memory_pool_with_allocator(self.python_malloc_callback,
                                            self.python_free_callback) as data:
            # start to hit another PyTorch bug in PyTorch 2.6,
            # possibly because of gc-related issue w.r.t. the allocator and
            # the memory pool.
            # to avoid the issue, we keep a reference of the data.
            # see https://github.com/pytorch/pytorch/issues/146431 .
            self.allocator_and_pools[tag] = data
            yield
            # PyTorch's bug, calling torch.cuda.empty_cache() will error
            # when using pluggable allocator, see
            # https://github.com/pytorch/pytorch/issues/145168 .
            # if we have some memory allocated and then freed,
            # the memory will not be released.
            # right now it is fine, because we only use this allocator
            # during weight loading and kv cache creation, where we only
            # allocate memory.
            # TODO: we need to find a way to release the memory,
            # i.e. calling torch.cuda.empty_cache()
            self.current_tag = old_tag

    def get_current_usage(self) -> int:
        """
        Get the total number of bytes allocated in the memory pool.
        """
        sum_bytes: int = 0
        for ptr, data in self.pointer_to_data.items():
            handle = data.handle
            sum_bytes += handle[1]
        return sum_bytes

allocator_and_pools instance-attribute

allocator_and_pools: dict[str, Any] = {}

current_tag instance-attribute

current_tag: str = default_tag

default_tag class-attribute instance-attribute

default_tag: str = 'default'

instance class-attribute instance-attribute

instance: CuMemAllocator = None

pointer_to_data instance-attribute

pointer_to_data: dict[int, AllocationData] = {}

__init__

__init__()
Source code in vllm/device_allocator/cumem.py
def __init__(self):
    conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
    assert "expandable_segments:True" not in conf, \
        ("Expandable segments are not compatible with memory pool. "
        "Please track https://github.com/pytorch/pytorch/issues/147851 "
        "for the latest updates.")

    self.pointer_to_data: dict[int, AllocationData] = {}
    self.current_tag: str = CuMemAllocator.default_tag
    self.allocator_and_pools: dict[str, Any] = {}

get_current_usage

get_current_usage() -> int

Get the total number of bytes allocated in the memory pool.
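
A short illustrative snippet:

from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()
used = allocator.get_current_usage()
print(f"cumem pool usage: {used / (1 << 30):.2f} GiB")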

Source code in vllm/device_allocator/cumem.py
def get_current_usage(self) -> int:
    """
    Get the total number of bytes allocated in the memory pool.
    """
    sum_bytes: int = 0
    for ptr, data in self.pointer_to_data.items():
        handle = data.handle
        sum_bytes += handle[1]
    return sum_bytes

get_instance staticmethod

get_instance() -> CuMemAllocator

CuMemAllocator is a singleton class. We cannot call the constructor directly. Call this method to get the instance.
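
For example, repeated calls return the same object (a brief illustrative check):

from vllm.device_allocator.cumem import CuMemAllocator

a = CuMemAllocator.get_instance()
b = CuMemAllocator.get_instance()
assert a is b  # always the same singleton instance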

Source code in vllm/device_allocator/cumem.py
@staticmethod
def get_instance() -> "CuMemAllocator":
    """
    CuMemAllocator is a singleton class.
    We cannot call the constructor directly.
    Call this method to get the instance.
    """
    assert cumem_available, "cumem allocator is not available"
    if CuMemAllocator.instance is None:
        CuMemAllocator.instance = CuMemAllocator()
    return CuMemAllocator.instance

python_free_callback

python_free_callback(ptr: int) -> HandleType

Internal method to look up the allocation data when memory is freed in the memory pool.

Source code in vllm/device_allocator/cumem.py
def python_free_callback(self, ptr: int) -> HandleType:
    """
    Internal method to look up the allocation data
    when memory is freed in the memory pool."""
    data = self.pointer_to_data.pop(ptr)
    if data.cpu_backup_tensor is not None:
        data.cpu_backup_tensor = None
    return data.handle

python_malloc_callback

python_malloc_callback(
    allocation_handle: HandleType,
) -> None

Internal method to store the allocation data when memory is allocated in the memory pool.

Source code in vllm/device_allocator/cumem.py
def python_malloc_callback(self, allocation_handle: HandleType) -> None:
    """
    Internal method to store the allocation data
    when memory is allocated in the memory pool."""
    py_d_mem = allocation_handle[2]
    self.pointer_to_data[py_d_mem] = AllocationData(
        allocation_handle, self.current_tag)
    return

sleep

sleep(
    offload_tags: Optional[
        Union[tuple[str, ...], str]
    ] = None,
) -> None

Put the allocator in sleep mode. All memory allocations with the specified tags are offloaded to CPU memory, and the rest are discarded.

:param offload_tags: The tags of the memory allocations to offload. The remaining allocations are discarded.
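
A brief sketch of the accepted argument forms (the tag name "weights" is illustrative); the calls below are alternatives, not a sequence:

from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()

allocator.sleep()                                     # offload only the "default" tag
allocator.sleep(offload_tags="weights")               # a single tag as a string
allocator.sleep(offload_tags=("weights", "default"))  # several tags as a tuple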

Source code in vllm/device_allocator/cumem.py
def sleep(
        self,
        offload_tags: Optional[Union[tuple[str, ...],
                                     str]] = None) -> None:
    """
    Put the allocator in sleep mode.
    All data in the memory allocation with the specified tag will be
    offloaded to CPU memory, and others will be discarded.

    :param offload_tags: The tags of the memory allocation that will be
        offloaded. The rest of the memory allocation will be discarded.
    """
    if offload_tags is None:
        # by default, allocated tensors are offloaded
        # when the allocator sleeps
        offload_tags = (CuMemAllocator.default_tag, )
    elif isinstance(offload_tags, str):
        offload_tags = (offload_tags, )

    assert isinstance(offload_tags, tuple)

    for ptr, data in self.pointer_to_data.items():
        handle = data.handle
        if data.tag in offload_tags:
            size_in_bytes = handle[1]
            cpu_backup_tensor = torch.empty(
                size_in_bytes,
                dtype=torch.uint8,
                device='cpu',
                pin_memory=is_pin_memory_available())
            cpu_ptr = cpu_backup_tensor.data_ptr()
            libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes)
            data.cpu_backup_tensor = cpu_backup_tensor
        unmap_and_release(handle)

    gc.collect()
    torch.cuda.empty_cache()

use_memory_pool

use_memory_pool(tag: Optional[str] = None)

A context manager to use the memory pool. All memory allocated inside the context comes from the memory pool and carries the specified tag.

:param tag: The tag of the memory allocation. If None, the default tag is used.
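
A brief sketch, assuming a CUDA device is available (the tag name is illustrative):

import torch
from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()

# Tensors created here carry the default tag ("default").
with allocator.use_memory_pool():
    a = torch.zeros(1 << 20, dtype=torch.uint8, device="cuda")

# Tensors created here carry the tag "weights".
with allocator.use_memory_pool(tag="weights"):
    b = torch.zeros(1 << 20, dtype=torch.uint8, device="cuda")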

Source code in vllm/device_allocator/cumem.py
@contextmanager
def use_memory_pool(self, tag: Optional[str] = None):
    """
    A context manager to use the memory pool.
    All memory allocation created inside the context will be allocated
    in the memory pool, and has the specified tag.

    :param tag: The tag of the memory allocation. If None, the default tag
        will be used.
    """
    if tag is None:
        tag = CuMemAllocator.default_tag

    assert isinstance(tag, str)

    old_tag = self.current_tag
    self.current_tag = tag
    with use_memory_pool_with_allocator(self.python_malloc_callback,
                                        self.python_free_callback) as data:
        # start to hit another PyTorch bug in PyTorch 2.6,
        # possibly because of gc-related issue w.r.t. the allocator and
        # the memory pool.
        # to avoid the issue, we keep a reference of the data.
        # see https://github.com/pytorch/pytorch/issues/146431 .
        self.allocator_and_pools[tag] = data
        yield
        # PyTorch's bug, calling torch.cuda.empty_cache() will error
        # when using pluggable allocator, see
        # https://github.com/pytorch/pytorch/issues/145168 .
        # if we have some memory allocated and then freed,
        # the memory will not be released.
        # right now it is fine, because we only use this allocator
        # during weight loading and kv cache creation, where we only
        # allocate memory.
        # TODO: we need to find a way to release the memory,
        # i.e. calling torch.cuda.empty_cache()
        self.current_tag = old_tag

wake_up

wake_up(tags: Optional[list[str]] = None) -> None

Wake up the allocator from sleep mode. All data that was previously offloaded is loaded back to GPU memory, and the rest of the data is given empty (uninitialized) memory.

:param tags: The tags of the memory allocations to load back to GPU memory. If None, all allocations are loaded back to GPU memory.
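
A brief sketch (the tag name is illustrative); the two calls are alternatives, with selective wake-up useful when GPU memory should be restored in stages:

from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()

allocator.wake_up()                  # restore every allocation
allocator.wake_up(tags=["weights"])  # alternatively, restore only the tagged allocations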

Source code in vllm/device_allocator/cumem.py
def wake_up(self, tags: Optional[list[str]] = None) -> None:
    """
    Wake up the allocator from sleep mode.
    All data that is previously offloaded will be loaded back to GPU 
    memory, and the rest of the data will have empty memory.

    :param tags: The tags of the memory allocation that will be loaded
        back to GPU memory. If None, all memory allocation will be loaded
        back to GPU memory.
    """
    for ptr, data in self.pointer_to_data.items():
        if tags is None or data.tag in tags:
            handle = data.handle
            create_and_map(handle)
            if data.cpu_backup_tensor is not None:
                cpu_backup_tensor = data.cpu_backup_tensor
                if cpu_backup_tensor is not None:
                    size_in_bytes = cpu_backup_tensor.numel(
                    ) * cpu_backup_tensor.element_size()
                    cpu_ptr = cpu_backup_tensor.data_ptr()
                    libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes)
                    data.cpu_backup_tensor = None

create_and_map

create_and_map(allocation_handle: HandleType) -> None
Source code in vllm/device_allocator/cumem.py
def create_and_map(allocation_handle: HandleType) -> None:
    python_create_and_map(*allocation_handle)

find_loaded_library

find_loaded_library(lib_name) -> Optional[str]

According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, the file /proc/self/maps contains the memory maps of the process, including the shared libraries it has loaded. We can use this file to find the path of a loaded library.
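
An illustrative sketch; the maps line in the comment is made up, only the matching behaviour comes from the source below:

from vllm.device_allocator.cumem import find_loaded_library

# A /proc/self/maps entry the function would match looks roughly like
# (address range, permissions, offset, device, inode, path):
#   7f2c3a000000-7f2c3a200000 r-xp 00000000 08:01 123456  /usr/lib/python3/site-packages/vllm/cumem_allocator.abi3.so
path = find_loaded_library("cumem_allocator")
if path is None:
    print("cumem_allocator is not loaded in this process")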

Source code in vllm/device_allocator/cumem.py
def find_loaded_library(lib_name) -> Optional[str]:
    """
    According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
    the file `/proc/self/maps` contains the memory maps of the process, which includes the
    shared libraries loaded by the process. We can use this file to find the path of the
    a loaded library.
    """ # noqa
    found_line = None
    with open("/proc/self/maps") as f:
        for line in f:
            if lib_name in line:
                found_line = line
                break
    if found_line is None:
        # the library is not loaded in the current process
        return None
    # if lib_name is libcudart, we need to match a line with:
    # address /path/to/libcudart-hash.so.11.0
    start = found_line.index("/")
    path = found_line[start:].strip()
    filename = path.split("/")[-1]
    assert filename.rpartition(".so")[0].startswith(lib_name), \
        f"Unexpected filename: {filename} for library {lib_name}"
    return path

get_pluggable_allocator

get_pluggable_allocator(
    python_malloc_fn: Callable[[int], int],
    python_free_func: Callable[[int, int], None],
) -> CUDAPluggableAllocator
Source code in vllm/device_allocator/cumem.py
def get_pluggable_allocator(
    python_malloc_fn: Callable[[int],
                               int], python_free_func: Callable[[int, int],
                                                                None]
) -> torch.cuda.memory.CUDAPluggableAllocator:
    init_module(python_malloc_fn, python_free_func)
    new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
        lib_name, 'my_malloc', 'my_free')
    return new_alloc

unmap_and_release

unmap_and_release(allocation_handle: HandleType) -> None
Source code in vllm/device_allocator/cumem.py
def unmap_and_release(allocation_handle: HandleType) -> None:
    python_unmap_and_release(*allocation_handle)

use_memory_pool_with_allocator

use_memory_pool_with_allocator(
    python_malloc_fn: Callable[[int], int],
    python_free_func: Callable[[int, int], None],
) -> None
Source code in vllm/device_allocator/cumem.py
@contextmanager
def use_memory_pool_with_allocator(
        python_malloc_fn: Callable[[int], int],
        python_free_func: Callable[[int, int], None]) -> None:
    new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
    mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
    with torch.cuda.memory.use_mem_pool(mem_pool):
        yield mem_pool, new_alloc
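
A minimal sketch of driving this low-level context manager directly, mirroring the handle bookkeeping that CuMemAllocator performs; most callers should go through CuMemAllocator.use_memory_pool instead. Assumes a CUDA device and the compiled cumem_allocator extension.

import torch
from vllm.device_allocator.cumem import (HandleType,
                                          use_memory_pool_with_allocator)

handles: dict[int, HandleType] = {}

def malloc_cb(handle: HandleType) -> None:
    handles[handle[2]] = handle  # key by the device pointer

def free_cb(ptr: int) -> HandleType:
    return handles.pop(ptr)      # hand the original handle back to the C extension

with use_memory_pool_with_allocator(malloc_cb, free_cb) as (mem_pool, allocator):
    t = torch.empty(1 << 20, dtype=torch.uint8, device="cuda")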