Skip to content

vllm.distributed.device_communicators.neuron_communicator

NeuronCommunicator

Bases: DeviceCommunicatorBase

Source code in vllm/distributed/device_communicators/neuron_communicator.py
class NeuronCommunicator(DeviceCommunicatorBase):
    """Device communicator for Neuron devices.

    Both collectives delegate to ``xm`` (presumably
    ``torch_xla.core.xla_model``; the import is outside this view — confirm).
    """

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """Sum ``x`` element-wise across participants via ``xm.all_reduce``.

        No replica group is passed, so xm's default grouping applies.
        """
        return xm.all_reduce(xm.REDUCE_SUM, x)

    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Gather ``x`` from all participants along its last dimension.

        Args:
            x: Local tensor to contribute to the gather.
            dim: Dimension to gather along. Only the last dimension is
                supported on Neuron; it may be given as ``-1`` or as its
                non-negative index ``x.dim() - 1``.

        Returns:
            The result of ``xm.all_gather`` along the last dimension.

        Raises:
            ValueError: If ``dim`` does not refer to the last dimension.
        """
        # Explicit check instead of ``assert`` so the constraint still holds
        # under ``python -O``; short-circuit keeps the default path cheap.
        if dim != -1 and dim != x.dim() - 1:
            raise ValueError(
                f"Neuron only supports dim=-1 for all-gather (got dim={dim}).")
        # Normalize to -1 so xm sees exactly what it saw before this change.
        return xm.all_gather(x, dim=-1)

all_gather

all_gather(x: Tensor, dim: int = -1) -> Tensor
Source code in vllm/distributed/device_communicators/neuron_communicator.py
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Gather ``x`` from all participants along its last dimension.

    Args:
        x: Local tensor to contribute to the gather.
        dim: Dimension to gather along. Only the last dimension is supported
            on Neuron; it may be given as ``-1`` or as its non-negative index
            ``x.dim() - 1``.

    Returns:
        The result of ``xm.all_gather`` along the last dimension.

    Raises:
        ValueError: If ``dim`` does not refer to the last dimension.
    """
    # Explicit check instead of ``assert`` so the constraint still holds
    # under ``python -O``; short-circuit keeps the default path cheap.
    if dim != -1 and dim != x.dim() - 1:
        raise ValueError(
            f"Neuron only supports dim=-1 for all-gather (got dim={dim}).")
    # Normalize to -1 so xm sees exactly what it saw before this change.
    return xm.all_gather(x, dim=-1)

all_reduce

all_reduce(x: Tensor) -> Tensor
Source code in vllm/distributed/device_communicators/neuron_communicator.py
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
    """Return the element-wise sum of ``x`` across participants.

    Delegates to ``xm.all_reduce`` with the ``xm.REDUCE_SUM`` reduction; no
    replica group is passed, so xm's default grouping applies.
    """
    reduce_op = xm.REDUCE_SUM
    summed = xm.all_reduce(reduce_op, x)
    return summed