vllm.distributed.device_communicators.cuda_communicator

logger module-attribute

logger = init_logger(__name__)

CudaCommunicator

Bases: DeviceCommunicatorBase

Source code in vllm/distributed/device_communicators/cuda_communicator.py
class CudaCommunicator(DeviceCommunicatorBase):

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)
        if "tp" not in unique_name:
            # only tp uses custom allreduce
            use_custom_allreduce = False
        else:
            from vllm.distributed.parallel_state import (
                _ENABLE_CUSTOM_ALL_REDUCE)
            use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE

        # ep does not use pynccl
        use_pynccl = "ep" not in unique_name

        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce

        # lazy import to avoid documentation build error
        from vllm.distributed.device_communicators.custom_all_reduce import (
            CustomAllreduce)
        from vllm.distributed.device_communicators.pynccl import (
            PyNcclCommunicator)
        from vllm.distributed.device_communicators.quick_all_reduce import (
            QuickAllReduce)

        self.pynccl_comm: Optional[PyNcclCommunicator] = None
        if use_pynccl and self.world_size > 1:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group,
                device=self.device,
            )

        self.ca_comm: Optional[CustomAllreduce] = None
        self.qr_comm: Optional[QuickAllReduce] = None
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            self.ca_comm = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
            )

            if current_platform.is_rocm():
                # Initialize a custom quick all-reduce implementation for AMD.
                # Quick reduce is designed as a complement to custom allreduce.
                # Based on quickreduce (https://github.com/mk1-project/quickreduce).
                # On ROCm, use_custom_allreduce == True currently implies
                # an MI300-series GPU.
                self.qr_comm = QuickAllReduce(group=self.cpu_group,
                                              device=self.device)
        if self.use_all2all:
            all2all_backend = envs.VLLM_ALL2ALL_BACKEND
            if all2all_backend == "naive":
                from .all2all import NaiveAll2AllManager
                self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
                logger.info("Using naive all2all manager.")
            elif all2all_backend == "pplx":
                from .all2all import PPLXAll2AllManager
                self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
                logger.info("Using PPLX all2all manager.")
            elif all2all_backend == "deepep_high_throughput":
                from .all2all import DeepEPHTAll2AllManager
                self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
                logger.info("Using DeepEP High-Throughput all2all manager.")
            elif all2all_backend == "deepep_low_latency":
                from .all2all import DeepEPLLAll2AllManager
                self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
                logger.info("Using DeepEP Low-Latency all2all manager.")
            else:
                raise ValueError(f"Unknown all2all backend: {all2all_backend}")

    def all_reduce(self, input_):
        # Always try quick reduce first, then custom allreduce, and then
        # pynccl (quick reduce is only for ROCm MI300-series GPUs).
        qr_comm = self.qr_comm
        if qr_comm is not None and not qr_comm.disabled and \
            qr_comm.should_quick_allreduce(input_):
            out = qr_comm.quick_all_reduce(input_)
            assert out is not None
            return out
        ca_comm = self.ca_comm
        if ca_comm is not None and not ca_comm.disabled and \
            ca_comm.should_custom_ar(input_):
            out = ca_comm.custom_all_reduce(input_)
            assert out is not None
            return out
        pynccl_comm = self.pynccl_comm
        assert pynccl_comm is not None
        out = pynccl_comm.all_reduce(input_)
        if out is None:
            # fall back to the default all-reduce using PyTorch.
            # this usually happens during testing.
            # when we run the model, allreduce only happens for the TP
            # group, where we always have either custom allreduce or pynccl.
            out = input_.clone()
            torch.distributed.all_reduce(out, group=self.device_group)
        return out

    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
        world_size = self.world_size
        pynccl_comm = self.pynccl_comm
        assert pynccl_comm is not None
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()

        # Note: This will produce an incorrect answer if we don't make
        # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
        input_tensor = input_.movedim(0, dim).contiguous()

        assert input_tensor.shape[0] % world_size == 0
        chunk_size = input_tensor.shape[0] // world_size
        output_shape = (chunk_size, ) + input_tensor.shape[1:]

        output = torch.empty(output_shape,
                             dtype=input_tensor.dtype,
                             device=input_tensor.device)

        pynccl_comm.reduce_scatter(output, input_tensor)

        # Reshape before returning
        return output.movedim(0, dim).contiguous()

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way"""
        """NOTE: `dst` is the local rank of the destination rank."""
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size

        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.send(tensor, dst)
        else:
            torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.recv(tensor, src)
        else:
            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self):
        if self.pynccl_comm is not None:
            self.pynccl_comm = None
        if self.ca_comm is not None:
            self.ca_comm = None
        if self.all2all_manager is not None:
            self.all2all_manager.destroy()
            self.all2all_manager = None

    def dispatch(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        assert self.all2all_manager is not None
        hidden_states, router_logits = self.all2all_manager.dispatch(
            hidden_states, router_logits)
        return hidden_states, router_logits

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        assert self.all2all_manager is not None
        hidden_states = self.all2all_manager.combine(hidden_states)
        return hidden_states
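
A rough construction sketch (not part of the vLLM source): in a real deployment the communicator is created for you during vLLM's distributed initialization in vllm.distributed.parallel_state, so the snippet below only illustrates how the constructor arguments map onto process groups. The torchrun launch, the gloo/nccl group split, and the unique_name value "tp:0" are all assumptions made for this example.

# Illustrative only: parallel_state normally builds these groups and picks
# the device; this sketch assumes a torchrun launch with one process per GPU.
import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.cuda_communicator import (
    CudaCommunicator)

dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
device = torch.device(f"cuda:{local_rank}")

# A CPU (gloo) group for bootstrapping plus a matching device (nccl) group.
cpu_group = dist.new_group(backend="gloo")
device_group = dist.new_group(backend="nccl")

# The "tp" substring in unique_name is what enables the custom-allreduce path.
comm = CudaCommunicator(cpu_group=cpu_group,
                        device=device,
                        device_group=device_group,
                        unique_name="tp:0")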

all2all_manager instance-attribute

all2all_manager = NaiveAll2AllManager(cpu_group)

ca_comm instance-attribute

ca_comm: Optional[CustomAllreduce] = None

pynccl_comm instance-attribute

pynccl_comm: Optional[PyNcclCommunicator] = None

qr_comm instance-attribute

qr_comm: Optional[QuickAllReduce] = None

use_custom_allreduce instance-attribute

use_custom_allreduce = use_custom_allreduce

use_pynccl instance-attribute

use_pynccl = use_pynccl

__init__

__init__(
    cpu_group: ProcessGroup,
    device: Optional[device] = None,
    device_group: Optional[ProcessGroup] = None,
    unique_name: str = "",
)
Source code in vllm/distributed/device_communicators/cuda_communicator.py
def __init__(self,
             cpu_group: ProcessGroup,
             device: Optional[torch.device] = None,
             device_group: Optional[ProcessGroup] = None,
             unique_name: str = ""):
    super().__init__(cpu_group, device, device_group, unique_name)
    if "tp" not in unique_name:
        # only tp uses custom allreduce
        use_custom_allreduce = False
    else:
        from vllm.distributed.parallel_state import (
            _ENABLE_CUSTOM_ALL_REDUCE)
        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE

    # ep does not use pynccl
    use_pynccl = "ep" not in unique_name

    self.use_pynccl = use_pynccl
    self.use_custom_allreduce = use_custom_allreduce

    # lazy import to avoid documentation build error
    from vllm.distributed.device_communicators.custom_all_reduce import (
        CustomAllreduce)
    from vllm.distributed.device_communicators.pynccl import (
        PyNcclCommunicator)
    from vllm.distributed.device_communicators.quick_all_reduce import (
        QuickAllReduce)

    self.pynccl_comm: Optional[PyNcclCommunicator] = None
    if use_pynccl and self.world_size > 1:
        self.pynccl_comm = PyNcclCommunicator(
            group=self.cpu_group,
            device=self.device,
        )

    self.ca_comm: Optional[CustomAllreduce] = None
    self.qr_comm: Optional[QuickAllReduce] = None
    if use_custom_allreduce and self.world_size > 1:
        # Initialize a custom fast all-reduce implementation.
        self.ca_comm = CustomAllreduce(
            group=self.cpu_group,
            device=self.device,
        )

        if current_platform.is_rocm():
            # Initialize a custom quick all-reduce implementation for AMD.
            # Quick reduce is designed as a complement to custom allreduce.
            # Based on quickreduce (https://github.com/mk1-project/quickreduce).
            # On ROCm, use_custom_allreduce == True currently implies
            # an MI300-series GPU.
            self.qr_comm = QuickAllReduce(group=self.cpu_group,
                                          device=self.device)
    if self.use_all2all:
        all2all_backend = envs.VLLM_ALL2ALL_BACKEND
        if all2all_backend == "naive":
            from .all2all import NaiveAll2AllManager
            self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
            logger.info("Using naive all2all manager.")
        elif all2all_backend == "pplx":
            from .all2all import PPLXAll2AllManager
            self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
            logger.info("Using PPLX all2all manager.")
        elif all2all_backend == "deepep_high_throughput":
            from .all2all import DeepEPHTAll2AllManager
            self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
            logger.info("Using DeepEP High-Throughput all2all manager.")
        elif all2all_backend == "deepep_low_latency":
            from .all2all import DeepEPLLAll2AllManager
            self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
            logger.info("Using DeepEP Low-Latency all2all manager.")
        else:
            raise ValueError(f"Unknown all2all backend: {all2all_backend}")
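
Which all2all manager is instantiated here is determined solely by the VLLM_ALL2ALL_BACKEND environment variable, read through vllm.envs at construction time. A minimal sketch of pinning that choice follows; setting the variable before vLLM initializes its distributed state is an assumption about ordering, not a documented guarantee.

import os

# Accepted values, per the branches above: "naive", "pplx",
# "deepep_high_throughput", "deepep_low_latency".
os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_low_latency"
# ...then start vLLM / initialize its distributed state as usual.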

all_reduce

all_reduce(input_)
Source code in vllm/distributed/device_communicators/cuda_communicator.py
def all_reduce(self, input_):
    # Always try quick reduce first, then custom allreduce, and then
    # pynccl (quick reduce is only for ROCm MI300-series GPUs).
    qr_comm = self.qr_comm
    if qr_comm is not None and not qr_comm.disabled and \
        qr_comm.should_quick_allreduce(input_):
        out = qr_comm.quick_all_reduce(input_)
        assert out is not None
        return out
    ca_comm = self.ca_comm
    if ca_comm is not None and not ca_comm.disabled and \
        ca_comm.should_custom_ar(input_):
        out = ca_comm.custom_all_reduce(input_)
        assert out is not None
        return out
    pynccl_comm = self.pynccl_comm
    assert pynccl_comm is not None
    out = pynccl_comm.all_reduce(input_)
    if out is None:
        # fall back to the default all-reduce using PyTorch.
        # this usually happens during testing.
        # when we run the model, allreduce only happens for the TP
        # group, where we always have either custom allreduce or pynccl.
        out = input_.clone()
        torch.distributed.all_reduce(out, group=self.device_group)
    return out
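
From the caller's side the quick-reduce / custom-allreduce / pynccl selection is invisible, and the method behaves like a plain allreduce over the communicator's group. A usage sketch, reusing comm and torch from the construction sketch above (tensor shape and dtype are illustrative):

hidden = torch.randn(8, 4096, dtype=torch.bfloat16, device=comm.device)
# Quick reduce, custom allreduce and pynccl are tried in that order; the
# result holds the elementwise sum of `hidden` across all ranks in the group.
reduced = comm.all_reduce(hidden)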

combine

combine(hidden_states: Tensor) -> Tensor
Source code in vllm/distributed/device_communicators/cuda_communicator.py
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
    assert self.all2all_manager is not None
    hidden_states = self.all2all_manager.combine(hidden_states)
    return hidden_states

destroy

destroy()
Source code in vllm/distributed/device_communicators/cuda_communicator.py
def destroy(self):
    if self.pynccl_comm is not None:
        self.pynccl_comm = None
    if self.ca_comm is not None:
        self.ca_comm = None
    if self.all2all_manager is not None:
        self.all2all_manager.destroy()
        self.all2all_manager = None

dispatch

dispatch(
    hidden_states: Tensor, router_logits: Tensor
) -> tuple[Tensor, Tensor]
Source code in vllm/distributed/device_communicators/cuda_communicator.py
def dispatch(
        self, hidden_states: torch.Tensor,
        router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    assert self.all2all_manager is not None
    hidden_states, router_logits = self.all2all_manager.dispatch(
        hidden_states, router_logits)
    return hidden_states, router_logits
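
dispatch and combine form a pair in expert-parallel MoE layers: dispatch routes tokens (with their router logits) to the ranks that own the selected experts, and combine brings the expert outputs back to the original token layout. A sketch reusing comm from the construction example above, assuming an all2all manager was created (i.e. use_all2all was set) and using purely illustrative shapes:

# hidden_states: (num_tokens, hidden_size); router_logits: (num_tokens, num_experts)
hidden_states = torch.randn(16, 4096, dtype=torch.bfloat16, device=comm.device)
router_logits = torch.randn(16, 64, dtype=torch.float32, device=comm.device)

# Scatter tokens to the expert-owning ranks...
hidden_states, router_logits = comm.dispatch(hidden_states, router_logits)
# ...run the local experts on the received tokens here...
# ...then gather the expert outputs back.
hidden_states = comm.combine(hidden_states)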

recv

recv(
    size: Size, dtype: dtype, src: Optional[int] = None
) -> Tensor

Receives a tensor from the source rank. NOTE: `src` is the local rank of the source rank.

Source code in vllm/distributed/device_communicators/cuda_communicator.py
def recv(self,
         size: torch.Size,
         dtype: torch.dtype,
         src: Optional[int] = None) -> torch.Tensor:
    """Receives a tensor from the source rank."""
    """NOTE: `src` is the local rank of the source rank."""
    if src is None:
        src = (self.rank_in_group - 1) % self.world_size

    tensor = torch.empty(size, dtype=dtype, device=self.device)
    pynccl_comm = self.pynccl_comm
    if pynccl_comm is not None and not pynccl_comm.disabled:
        pynccl_comm.recv(tensor, src)
    else:
        torch.distributed.recv(tensor, self.ranks[src], self.device_group)
    return tensor

reduce_scatter

reduce_scatter(input_: Tensor, dim: int = -1)
Source code in vllm/distributed/device_communicators/cuda_communicator.py
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
    world_size = self.world_size
    pynccl_comm = self.pynccl_comm
    assert pynccl_comm is not None
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Note: This will produce an incorrect answer if we don't make
    # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
    input_tensor = input_.movedim(0, dim).contiguous()

    assert input_tensor.shape[0] % world_size == 0
    chunk_size = input_tensor.shape[0] // world_size
    output_shape = (chunk_size, ) + input_tensor.shape[1:]

    output = torch.empty(output_shape,
                         dtype=input_tensor.dtype,
                         device=input_tensor.device)

    pynccl_comm.reduce_scatter(output, input_tensor)

    # Reshape before returning
    return output.movedim(0, dim).contiguous()
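
The scattered dimension must be divisible by the world size: each rank ends up with its 1/world_size slice along dim, summed across all ranks. A shape-only sketch reusing comm from the construction example above, assuming world_size == 4:

x = torch.randn(8, 4096, device=comm.device)     # replicated input on every rank
y = comm.reduce_scatter(x, dim=-1)               # this rank's reduced slice
assert y.shape == (8, 4096 // comm.world_size)   # (8, 1024) with 4 ranks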

send

send(tensor: Tensor, dst: Optional[int] = None) -> None

Sends a tensor to the destination rank in a non-blocking way. NOTE: `dst` is the local rank of the destination rank.

Source code in vllm/distributed/device_communicators/cuda_communicator.py
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
    """Sends a tensor to the destination rank in a non-blocking way"""
    """NOTE: `dst` is the local rank of the destination rank."""
    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size

    pynccl_comm = self.pynccl_comm
    if pynccl_comm is not None and not pynccl_comm.disabled:
        pynccl_comm.send(tensor, dst)
    else:
        torch.distributed.send(tensor, self.ranks[dst], self.device_group)
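
send and recv are meant to be used pairwise, typically between adjacent pipeline ranks; dst and src default to the next and previous local rank in the group. A sketch reusing comm from the construction example above, with an even world size assumed so that every send has a matching recv (shape and dtype are illustrative):

shape = torch.Size((4, 4096))
if comm.rank_in_group % 2 == 0:
    # Even ranks forward a tensor to the next rank (dst defaults to rank + 1).
    activations = torch.randn(shape, dtype=torch.bfloat16, device=comm.device)
    comm.send(activations)
else:
    # Odd ranks allocate a matching buffer and receive from the previous rank.
    activations = comm.recv(shape, torch.bfloat16)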