Bases: DeviceCommunicatorBase
Source code in vllm/distributed/device_communicators/neuron_communicator.py
| class NeuronCommunicator(DeviceCommunicatorBase):

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        return xm.all_reduce(xm.REDUCE_SUM, x)

    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        assert dim == -1, "Neuron only supports dim=-1 for all-gather."
        return xm.all_gather(x, dim=dim)
|
all_gather
Source code in vllm/distributed/device_communicators/neuron_communicator.py
| def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    assert dim == -1, "Neuron only supports dim=-1 for all-gather."
    return xm.all_gather(x, dim=dim)
|
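To illustrate what an all-gather along the last dimension produces, here is a minimal pure-Python sketch. It runs no Neuron or XLA code: `simulate_all_gather_last_dim` is a hypothetical helper, ranks are modeled as entries of a list, and the rank-order concatenation is an assumption about the usual all-gather convention, not something confirmed by the source above.

```python
from typing import List

Matrix = List[List[float]]


def simulate_all_gather_last_dim(shards: List[Matrix]) -> List[Matrix]:
    """Hypothetical stand-in for an all-gather with dim=-1.

    shards[r] is the 2-D tensor held by rank r. All shards are assumed to
    have the same number of rows. Returns one result per rank.
    """
    num_rows = len(shards[0])
    # Concatenate the matching row of every shard, in rank order,
    # which extends the last dimension.
    gathered = [
        [value for shard in shards for value in shard[row]]
        for row in range(num_rows)
    ]
    # Every rank receives its own copy of the same gathered tensor.
    return [[row[:] for row in gathered] for _ in shards]


rank0 = [[1.0, 2.0], [3.0, 4.0]]
rank1 = [[5.0, 6.0], [7.0, 8.0]]
results = simulate_all_gather_last_dim([rank0, rank1])
print(results[0])
```

With two ranks holding 2x2 shards, each rank ends up with a 2x4 tensor: the row count is unchanged and only the last dimension grows.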
all_reduce
Source code in vllm/distributed/device_communicators/neuron_communicator.py
| def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
    return xm.all_reduce(xm.REDUCE_SUM, x)
|
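The effect of a sum all-reduce can be sketched the same way in plain Python. `simulate_all_reduce_sum` is a hypothetical helper that models each rank's tensor as a flat list; it only illustrates the semantics and does not call into any Neuron or XLA library.

```python
from typing import List


def simulate_all_reduce_sum(per_rank: List[List[float]]) -> List[List[float]]:
    """Hypothetical stand-in for an all-reduce with a sum reduction.

    per_rank[r] is the flat tensor held by rank r. All ranks are assumed
    to hold tensors of the same length. Returns one result per rank.
    """
    # Sum the values at each position across all ranks.
    total = [sum(values) for values in zip(*per_rank)]
    # Every rank receives its own copy of the summed tensor.
    return [total[:] for _ in per_rank]


inputs = [[1.0, 2.0], [10.0, 20.0], [100.0, 200.0]]
outputs = simulate_all_reduce_sum(inputs)
print(outputs[0])
```

Unlike all-gather, the output shape matches the input shape: every rank holds the same element-wise sum.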