vllm.distributed.device_communicators.base_device_communicator

All2AllManagerBase

Source code in vllm/distributed/device_communicators/base_device_communicator.py
class All2AllManagerBase:

    def __init__(self, cpu_group):
        self.cpu_group = cpu_group

        # compute some common properties
        from vllm.distributed.parallel_state import (get_dp_group,
                                                     get_tp_group,
                                                     in_the_same_node_as)

        # all2all lives in ep group, which is merged from dp and tp group
        self.dp_group = get_dp_group()
        self.tp_group = get_tp_group()
        # no self.ep_group since self.ep_group is still under construction
        # when we create this object
        self.dp_rank = self.dp_group.rank_in_group
        self.dp_world_size = self.dp_group.world_size
        self.rank = dist.get_rank(cpu_group)
        self.world_size = dist.get_world_size(cpu_group)

        # all2all communication often has separate implementations for
        # intra-node and inter-node communication
        self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))

    def get_handle(self, kwargs):
        # Get a handle for the all2all communication based on the kwargs.
        # Different layers can have different configs, e.g. one layer has
        # hidden size 1024, another has 2048. The underlying implementation
        # usually caches the handle and reuses it for the same config.
        raise NotImplementedError

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        raise NotImplementedError

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        pass

cpu_group instance-attribute

cpu_group = cpu_group

dp_group instance-attribute

dp_group = get_dp_group()

dp_rank instance-attribute

dp_rank = rank_in_group

dp_world_size instance-attribute

dp_world_size = world_size

internode instance-attribute

internode = not all(
    in_the_same_node_as(cpu_group, source_rank=0)
)

rank instance-attribute

rank = get_rank(cpu_group)

tp_group instance-attribute

tp_group = get_tp_group()

world_size instance-attribute

world_size = get_world_size(cpu_group)

__init__

__init__(cpu_group)
Source code in vllm/distributed/device_communicators/base_device_communicator.py
def __init__(self, cpu_group):
    self.cpu_group = cpu_group

    # compute some common properties
    from vllm.distributed.parallel_state import (get_dp_group,
                                                 get_tp_group,
                                                 in_the_same_node_as)

    # all2all lives in ep group, which is merged from dp and tp group
    self.dp_group = get_dp_group()
    self.tp_group = get_tp_group()
    # no self.ep_group since self.ep_group is still under construction
    # when we create this object
    self.dp_rank = self.dp_group.rank_in_group
    self.dp_world_size = self.dp_group.world_size
    self.rank = dist.get_rank(cpu_group)
    self.world_size = dist.get_world_size(cpu_group)

    # all2all communication often has separate implementations for
    # intra-node and inter-node communication
    self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))

combine

combine(hidden_states: Tensor) -> Tensor
Source code in vllm/distributed/device_communicators/base_device_communicator.py
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
    raise NotImplementedError

destroy

destroy()
Source code in vllm/distributed/device_communicators/base_device_communicator.py
def destroy(self):
    pass

dispatch

dispatch(hidden_states: Tensor, router_logits: Tensor)
Source code in vllm/distributed/device_communicators/base_device_communicator.py
def dispatch(self, hidden_states: torch.Tensor,
             router_logits: torch.Tensor):
    raise NotImplementedError

get_handle

get_handle(kwargs)
Source code in vllm/distributed/device_communicators/base_device_communicator.py
def get_handle(self, kwargs):
    # Get a handle for the all2all communication based on the kwargs.
    # Different layers can have different configs, e.g. one layer has
    # hidden size 1024, another has 2048. The underlying implementation
    # usually caches the handle and reuses it for the same config.
    raise NotImplementedError
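
The caching contract described in the comments above is typically implemented with the Cache helper documented below. A hedged sketch, assuming a hypothetical handle type and manager (neither is a real vLLM backend; the kwargs keys are assumed to match the handle's constructor):

from vllm.distributed.device_communicators.base_device_communicator import (
    All2AllManagerBase, Cache)


class FakeHandle:
    """Hypothetical stand-in for a backend-specific all2all handle."""

    def __init__(self, hidden_size: int, dtype: str):
        self.hidden_size = hidden_size
        self.dtype = dtype


class SketchAll2AllManager(All2AllManagerBase):
    """Hypothetical subclass illustrating per-config handle caching."""

    def __init__(self, cpu_group):
        super().__init__(cpu_group)
        self.handle_cache = Cache()

    def get_handle(self, kwargs):
        # Layers sharing the same config (e.g. hidden size, dtype) reuse one
        # handle; an unseen config lazily creates a new one.
        return self.handle_cache.get_or_create(kwargs, FakeHandle)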

Cache

Source code in vllm/distributed/device_communicators/base_device_communicator.py
class Cache:

    def __init__(self):
        self._cache: WeakValueDictionary = WeakValueDictionary()
        self._lock = threading.RLock()  # Reentrant lock for thread safety

    def get_or_create(self, kwargs, func):
        # Create a hashable key from the kwargs
        key = tuple(sorted((k, v) for k, v in kwargs.items()))

        with self._lock:
            instance = self._cache.get(key)
            if instance is None:
                instance = func(**kwargs)
                self._cache[key] = instance
            return instance

_cache instance-attribute

_cache: WeakValueDictionary = WeakValueDictionary()

_lock instance-attribute

_lock = RLock()

__init__

__init__()
Source code in vllm/distributed/device_communicators/base_device_communicator.py
def __init__(self):
    self._cache: WeakValueDictionary = WeakValueDictionary()
    self._lock = threading.RLock()  # Reentrant lock for thread safety

get_or_create

get_or_create(kwargs, func)
Source code in vllm/distributed/device_communicators/base_device_communicator.py
def get_or_create(self, kwargs, func):
    # Create a hashable key from the kwargs
    key = tuple(sorted((k, v) for k, v in kwargs.items()))

    with self._lock:
        instance = self._cache.get(key)
        if instance is None:
            instance = func(**kwargs)
            self._cache[key] = instance
        return instance
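
A stand-alone illustration of the semantics: the cache key is the sorted kwargs tuple, so equal configs map to the same instance, and values are held weakly, so an entry can be collected once no caller keeps a strong reference. Handle here is a hypothetical value type, not part of vLLM.

from vllm.distributed.device_communicators.base_device_communicator import Cache


class Handle:
    """Hypothetical cached object (must support weak references)."""

    def __init__(self, hidden_size: int):
        self.hidden_size = hidden_size


cache = Cache()
a = cache.get_or_create({"hidden_size": 1024}, Handle)
b = cache.get_or_create({"hidden_size": 1024}, Handle)
assert a is b                # same config -> same cached instance

c = cache.get_or_create({"hidden_size": 2048}, Handle)
assert c is not a            # different config -> new instance

# Dropping the only strong references makes the 1024 entry eligible to
# disappear from the WeakValueDictionary.
del a, b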

DeviceCommunicatorBase

Base class for device-specific communicator. It can use the cpu_group to initialize the communicator. If the device has PyTorch integration (PyTorch can recognize its communication backend), the device_group will also be given.
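
A minimal sketch of the intended extension pattern, assuming only that vllm and torch are importable: a platform-specific communicator (the class name here is hypothetical, not a real vLLM backend) subclasses the base, inherits the generic collectives, and overrides the ones its device backend accelerates.

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)


class SketchDeviceCommunicator(DeviceCommunicatorBase):
    """Hypothetical example subclass; not part of vLLM."""

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # A real backend would dispatch to a device-optimized kernel here;
        # this sketch simply falls back to torch.distributed.
        dist.all_reduce(input_, group=self.device_group)
        return input_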

Source code in vllm/distributed/device_communicators/base_device_communicator.py
class DeviceCommunicatorBase:
    """
    Base class for device-specific communicator.
    It can use the `cpu_group` to initialize the communicator.
    If the device has PyTorch integration (PyTorch can recognize its
    communication backend), the `device_group` will also be given.
    """

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        self.device = device or torch.device("cpu")
        self.cpu_group = cpu_group
        self.device_group = device_group
        self.unique_name = unique_name
        self.rank = dist.get_rank(cpu_group)
        self.world_size = dist.get_world_size(cpu_group)
        self.ranks = dist.get_process_group_ranks(cpu_group)
        self.global_rank = dist.get_rank()
        self.global_world_size = dist.get_world_size()
        self.rank_in_group = dist.get_group_rank(self.cpu_group,
                                                 self.global_rank)

        use_ep = False
        from vllm.config import get_current_vllm_config
        config = get_current_vllm_config()
        if config is not None:
            # as long as we use data parallel (coupled data parallel
            # where all data parallel ranks execute forward together),
            # we initialize the all2all manager used in expert parallel.
            use_ep = config.parallel_config.data_parallel_size > 1

        self.use_all2all = "ep" in unique_name and use_ep
        self.all2all_manager: Optional[All2AllManagerBase] = None

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        dist.all_reduce(input_, group=self.device_group)
        return input_

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # NOTE: we have to use concat-style all-gather here;
        # stack-style all-gather has compatibility issues with
        # torch.compile. See https://github.com/pytorch/pytorch/issues/138795
        output_size = (input_size[0] * self.world_size, ) + input_size[1:]
        # Allocate output tensor.
        output_tensor = torch.empty(output_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        dist.all_gather_into_tensor(output_tensor,
                                    input_,
                                    group=self.device_group)
        # Reshape
        output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (self.world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
        return output_tensor

    def reduce_scatter(self,
                       input_: torch.Tensor,
                       dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()

        # Note: This will produce an incorrect answer if we don't make
        # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
        input_tensor = input_.movedim(0, dim).contiguous()

        assert input_tensor.shape[0] % world_size == 0
        chunk_size = input_tensor.shape[0] // world_size
        output_shape = (chunk_size, ) + input_tensor.shape[1:]

        output_tensor = torch.empty(output_shape,
                                    dtype=input_tensor.dtype,
                                    device=input_tensor.device)

        # Perform reduce-scatter operation
        torch.distributed.reduce_scatter_tensor(output_tensor,
                                                input_tensor,
                                                group=self.device_group)

        # Reshape before returning
        return output_tensor.movedim(0, dim).contiguous()

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()

        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None
        # Gather.
        torch.distributed.gather(input_,
                                 gather_list,
                                 dst=self.ranks[dst],
                                 group=self.device_group)
        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way"""
        """NOTE: `dst` is the local rank of the destination rank."""
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self):
        pass

    def prepare_communication_buffer_for_model(self,
                                               model: torch.nn.Module) -> None:
        """
        Prepare the communication buffer for the model.
        """
        if not self.use_all2all:
            return

        moe_modules = [
            module for module in model.modules()
            if module.__class__.__name__ == "FusedMoE"
        ]
        for module in moe_modules:
            module.quant_method.init_prepare_finalize(module.moe_config,
                                                      module.quant_config)

    def dispatch(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Dispatch the hidden states and router logits to the appropriate device.
        This is a no-op in the base class.
        """
        return hidden_states, router_logits

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Combine the hidden states from the appropriate device.
        This is a no-op in the base class.
        """
        return hidden_states

all2all_manager instance-attribute

all2all_manager: Optional[All2AllManagerBase] = None

cpu_group instance-attribute

cpu_group = cpu_group

device instance-attribute

device = device or device('cpu')

device_group instance-attribute

device_group = device_group

global_rank instance-attribute

global_rank = get_rank()

global_world_size instance-attribute

global_world_size = get_world_size()

rank instance-attribute

rank = get_rank(cpu_group)

rank_in_group instance-attribute

rank_in_group = get_group_rank(cpu_group, global_rank)

ranks instance-attribute

ranks = get_process_group_ranks(cpu_group)

unique_name instance-attribute

unique_name = unique_name

use_all2all instance-attribute

use_all2all = 'ep' in unique_name and use_ep

world_size instance-attribute

world_size = get_world_size(cpu_group)

__init__

__init__(
    cpu_group: ProcessGroup,
    device: Optional[device] = None,
    device_group: Optional[ProcessGroup] = None,
    unique_name: str = "",
)
Source code in vllm/distributed/device_communicators/base_device_communicator.py
def __init__(self,
             cpu_group: ProcessGroup,
             device: Optional[torch.device] = None,
             device_group: Optional[ProcessGroup] = None,
             unique_name: str = ""):
    self.device = device or torch.device("cpu")
    self.cpu_group = cpu_group
    self.device_group = device_group
    self.unique_name = unique_name
    self.rank = dist.get_rank(cpu_group)
    self.world_size = dist.get_world_size(cpu_group)
    self.ranks = dist.get_process_group_ranks(cpu_group)
    self.global_rank = dist.get_rank()
    self.global_world_size = dist.get_world_size()
    self.rank_in_group = dist.get_group_rank(self.cpu_group,
                                             self.global_rank)

    use_ep = False
    from vllm.config import get_current_vllm_config
    config = get_current_vllm_config()
    if config is not None:
        # as long as we use data parallel (coupled data parallel
        # where all data parallel ranks execute forward together),
        # we initialize the all2all manager used in expert parallel.
        use_ep = config.parallel_config.data_parallel_size > 1

    self.use_all2all = "ep" in unique_name and use_ep
    self.all2all_manager: Optional[All2AllManagerBase] = None

all_gather

all_gather(input_: Tensor, dim: int = -1) -> Tensor
Source code in vllm/distributed/device_communicators/base_device_communicator.py
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    input_size = input_.size()
    # NOTE: we have to use concat-style all-gather here;
    # stack-style all-gather has compatibility issues with
    # torch.compile. See https://github.com/pytorch/pytorch/issues/138795
    output_size = (input_size[0] * self.world_size, ) + input_size[1:]
    # Allocate output tensor.
    output_tensor = torch.empty(output_size,
                                dtype=input_.dtype,
                                device=input_.device)
    # All-gather.
    dist.all_gather_into_tensor(output_tensor,
                                input_,
                                group=self.device_group)
    # Reshape
    output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
    output_tensor = output_tensor.movedim(0, dim)
    output_tensor = output_tensor.reshape(input_size[:dim] +
                                          (self.world_size *
                                           input_size[dim], ) +
                                          input_size[dim + 1:])
    return output_tensor
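
A local walkthrough of the reshape bookkeeping above, with no process group involved: simulate world_size ranks by concatenating copies of one tensor (what all_gather_into_tensor produces when every rank happens to hold the same data), then apply the same movedim/reshape to recover the concat-on-`dim` layout.

import torch

world_size, dim = 2, 1
input_ = torch.arange(12, dtype=torch.float32).reshape(3, 4)
input_size = input_.size()

# Stand-in for the flat all-gather buffer: every "rank" holds the same tensor.
gathered = torch.cat([input_ for _ in range(world_size)], dim=0)   # (6, 4)

out = gathered.reshape((world_size, ) + input_size)                # (2, 3, 4)
out = out.movedim(0, dim)                                          # (3, 2, 4)
out = out.reshape(input_size[:dim] +
                  (world_size * input_size[dim], ) +
                  input_size[dim + 1:])
assert out.shape == (3, 8)  # dim 1 grew by world_size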

all_reduce

all_reduce(input_: Tensor) -> Tensor
Source code in vllm/distributed/device_communicators/base_device_communicator.py
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
    dist.all_reduce(input_, group=self.device_group)
    return input_

combine

combine(hidden_states: Tensor) -> Tensor

Combine the hidden states from the appropriate device. This is a no-op in the base class.

Source code in vllm/distributed/device_communicators/base_device_communicator.py
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """
    Combine the hidden states from the appropriate device.
    This is a no-op in the base class.
    """
    return hidden_states

destroy

destroy()
Source code in vllm/distributed/device_communicators/base_device_communicator.py
def destroy(self):
    pass

dispatch

dispatch(
    hidden_states: Tensor, router_logits: Tensor
) -> tuple[Tensor, Tensor]

Dispatch the hidden states and router logits to the appropriate device. This is a no-op in the base class.

Source code in vllm/distributed/device_communicators/base_device_communicator.py
def dispatch(
        self, hidden_states: torch.Tensor,
        router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Dispatch the hidden states and router logits to the appropriate device.
    This is a no-op in the base class.
    """
    return hidden_states, router_logits

gather

gather(
    input_: Tensor, dst: int = 0, dim: int = -1
) -> Optional[Tensor]

NOTE: We assume that the input tensor is on the same device across all the ranks. NOTE: dst is the local rank of the destination rank.

Source code in vllm/distributed/device_communicators/base_device_communicator.py
def gather(self,
           input_: torch.Tensor,
           dst: int = 0,
           dim: int = -1) -> Optional[torch.Tensor]:
    """
    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    NOTE: `dst` is the local rank of the destination rank.
    """
    world_size = self.world_size
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Allocate output tensor.
    if self.rank_in_group == dst:
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    else:
        gather_list = None
    # Gather.
    torch.distributed.gather(input_,
                             gather_list,
                             dst=self.ranks[dst],
                             group=self.device_group)
    if self.rank_in_group == dst:
        output_tensor = torch.cat(gather_list, dim=dim)
    else:
        output_tensor = None
    return output_tensor
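
Note the local/global rank bookkeeping: `dst` indexes into the group, and self.ranks translates it to the global rank that torch.distributed expects. A tiny stand-alone illustration with made-up group ranks:

ranks = [4, 5, 6, 7]     # hypothetical global ranks forming this group
dst = 1                  # group-local destination rank
global_dst = ranks[dst]  # what torch.distributed.gather is given as `dst`
assert global_dst == 5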

prepare_communication_buffer_for_model

prepare_communication_buffer_for_model(
    model: Module,
) -> None

Prepare the communication buffer for the model.

Source code in vllm/distributed/device_communicators/base_device_communicator.py
def prepare_communication_buffer_for_model(self,
                                           model: torch.nn.Module) -> None:
    """
    Prepare the communication buffer for the model.
    """
    if not self.use_all2all:
        return

    moe_modules = [
        module for module in model.modules()
        if module.__class__.__name__ == "FusedMoE"
    ]
    for module in moe_modules:
        module.quant_method.init_prepare_finalize(module.moe_config,
                                                  module.quant_config)

recv

recv(
    size: Size, dtype: dtype, src: Optional[int] = None
) -> Tensor

Receives a tensor from the source rank. NOTE: `src` is the local rank of the source rank.

Source code in vllm/distributed/device_communicators/base_device_communicator.py
def recv(self,
         size: torch.Size,
         dtype: torch.dtype,
         src: Optional[int] = None) -> torch.Tensor:
    """Receives a tensor from the source rank."""
    """NOTE: `src` is the local rank of the source rank."""
    if src is None:
        src = (self.rank_in_group - 1) % self.world_size

    tensor = torch.empty(size, dtype=dtype, device=self.device)
    torch.distributed.recv(tensor, self.ranks[src], self.device_group)
    return tensor

reduce_scatter

reduce_scatter(input_: Tensor, dim: int = -1) -> Tensor
Source code in vllm/distributed/device_communicators/base_device_communicator.py
def reduce_scatter(self,
                   input_: torch.Tensor,
                   dim: int = -1) -> torch.Tensor:
    world_size = self.world_size
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Note: This will produce an incorrect answer if we don't make
    # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
    input_tensor = input_.movedim(0, dim).contiguous()

    assert input_tensor.shape[0] % world_size == 0
    chunk_size = input_tensor.shape[0] // world_size
    output_shape = (chunk_size, ) + input_tensor.shape[1:]

    output_tensor = torch.empty(output_shape,
                                dtype=input_tensor.dtype,
                                device=input_tensor.device)

    # Perform reduce-scatter operation
    torch.distributed.reduce_scatter_tensor(output_tensor,
                                            input_tensor,
                                            group=self.device_group)

    # Reshape before returning
    return output_tensor.movedim(0, dim).contiguous()
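
A single-process sketch of the shape contract, not the real collective: the size of the scattered dimension must be divisible by world_size, each rank receives one chunk of it, and that chunk is the element-wise sum over ranks. With every simulated rank holding the same tensor, the result is simply world_size times one chunk.

import torch

world_size = 2
input_ = torch.ones(4, 6)                      # scatter along dim 0

rank = 0                                       # pretend to be rank 0
my_chunk = input_.chunk(world_size, dim=0)[rank]
# All simulated ranks contribute the same tensor, so the sum across ranks
# equals world_size copies of one chunk.
reduced = sum(my_chunk for _ in range(world_size))
assert reduced.shape == (2, 6)                 # dim 0 shrank by world_size
assert torch.equal(reduced, torch.full((2, 6), 2.0))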

send

send(tensor: Tensor, dst: Optional[int] = None) -> None

Sends a tensor to the destination rank in a non-blocking way. NOTE: `dst` is the local rank of the destination rank.

Source code in vllm/distributed/device_communicators/base_device_communicator.py
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
    """Sends a tensor to the destination rank in a non-blocking way"""
    """NOTE: `dst` is the local rank of the destination rank."""
    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size
    torch.distributed.send(tensor, self.ranks[dst], self.device_group)
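
The defaults in send and recv form a ring over the group: with no explicit peer, each rank sends to the next group-local rank and receives from the previous one. A quick illustration of the indexing:

world_size = 4
for rank_in_group in range(world_size):
    dst = (rank_in_group + 1) % world_size   # default send peer
    src = (rank_in_group - 1) % world_size   # default recv peer
    print(f"rank {rank_in_group}: send -> {dst}, recv <- {src}")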