vllm.distributed.device_communicators.pynccl_wrapper

__all__ module-attribute

__all__ = [
    "NCCLLibrary",
    "ncclDataTypeEnum",
    "ncclRedOpTypeEnum",
    "ncclUniqueId",
    "ncclComm_t",
    "cudaStream_t",
    "buffer_type",
]

buffer_type module-attribute

buffer_type = c_void_p

cudaStream_t module-attribute

cudaStream_t = c_void_p

logger module-attribute

logger = init_logger(__name__)

ncclComm_t module-attribute

ncclComm_t = c_void_p

ncclDataType_t module-attribute

ncclDataType_t = c_int

ncclRedOp_t module-attribute

ncclRedOp_t = c_int

ncclResult_t module-attribute

ncclResult_t = c_int
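
These aliases mirror the corresponding NCCL C typedefs: communicators, CUDA streams and data buffers cross the ctypes boundary as opaque pointers (c_void_p), while results, datatypes and reduction ops are plain C ints. A minimal sketch of how a device pointer and a stream are wrapped before being handed to the library (the tensor is purely illustrative):

import ctypes
import torch

t = torch.empty(8, device="cuda")
# device pointers and CUDA streams are passed to NCCL as raw addresses
sendbuff = ctypes.c_void_p(t.data_ptr())                            # buffer_type
stream = ctypes.c_void_p(torch.cuda.current_stream().cuda_stream)   # cudaStream_t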

Function dataclass

Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
@dataclass
class Function:
    name: str
    restype: Any
    argtypes: list[Any]

argtypes instance-attribute

argtypes: list[Any]

name instance-attribute

name: str

restype instance-attribute

restype: Any

__init__

__init__(
    name: str, restype: Any, argtypes: list[Any]
) -> None

NCCLLibrary

Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
class NCCLLibrary:
    exported_functions = [
        # const char* ncclGetErrorString(ncclResult_t result)
        Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
        # ncclResult_t  ncclGetVersion(int *version);
        Function("ncclGetVersion", ncclResult_t,
                 [ctypes.POINTER(ctypes.c_int)]),
        # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
        Function("ncclGetUniqueId", ncclResult_t,
                 [ctypes.POINTER(ncclUniqueId)]),
        # ncclResult_t  ncclCommInitRank(
        #   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
        # note that ncclComm_t is a pointer type, so the first argument
        # is a pointer to a pointer
        Function("ncclCommInitRank", ncclResult_t, [
            ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
            ctypes.c_int
        ]),
        # ncclResult_t  ncclAllReduce(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
        #   cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function("ncclAllReduce", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ncclRedOp_t, ncclComm_t, cudaStream_t
        ]),

        # ncclResult_t  ncclAllGather(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, ncclComm_t comm,
        #   cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function("ncclAllGather", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ncclComm_t, cudaStream_t
        ]),

        # ncclResult_t  ncclReduceScatter(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
        #   cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function("ncclReduceScatter", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ncclRedOp_t, ncclComm_t, cudaStream_t
        ]),

        # ncclResult_t  ncclSend(
        #   const void* sendbuff, size_t count, ncclDataType_t datatype,
        #   int dest, ncclComm_t comm, cudaStream_t stream);
        Function("ncclSend", ncclResult_t, [
            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
            ncclComm_t, cudaStream_t
        ]),

        # ncclResult_t  ncclRecv(
        #   void* recvbuff, size_t count, ncclDataType_t datatype,
        #   int src, ncclComm_t comm, cudaStream_t stream);
        Function("ncclRecv", ncclResult_t, [
            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
            ncclComm_t, cudaStream_t
        ]),

        # ncclResult_t ncclBroadcast(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, int root, ncclComm_t comm,
        #   cudaStream_t stream);
        Function("ncclBroadcast", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ctypes.c_int, ncclComm_t, cudaStream_t
        ]),

        # be cautious! this is a collective call, it will block until all
        # processes in the communicator have called this function.
        # because Python object destruction can happen in random order,
        # it is better not to call it at all.
        # ncclResult_t  ncclCommDestroy(ncclComm_t comm);
        Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: dict[str, Any] = {}

    # class attribute to store the mapping from library path
    #  to the corresponding dictionary
    path_to_dict_mapping: dict[str, dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):

        so_file = so_file or find_nccl_library()

        try:
            if so_file not in NCCLLibrary.path_to_dict_mapping:
                lib = ctypes.CDLL(so_file)
                NCCLLibrary.path_to_library_cache[so_file] = lib
            self.lib = NCCLLibrary.path_to_library_cache[so_file]
        except Exception as e:
            logger.error(
                "Failed to load NCCL library from %s. "
                "This is expected if you are not running on NVIDIA/AMD GPUs. "
                "Otherwise, the nccl library might not exist, be corrupted "
                "or it does not support the current platform %s. "
                "If you already have the library, please set the "
                "environment variable VLLM_NCCL_SO_PATH"
                " to point to the correct nccl library path.", so_file,
                platform.platform())
            raise e

        if so_file not in NCCLLibrary.path_to_dict_mapping:
            _funcs: dict[str, Any] = {}
            for func in NCCLLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
        self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]

    def ncclGetErrorString(self, result: ncclResult_t) -> str:
        return self._funcs["ncclGetErrorString"](result).decode("utf-8")

    def NCCL_CHECK(self, result: ncclResult_t) -> None:
        if result != 0:
            error_str = self.ncclGetErrorString(result)
            raise RuntimeError(f"NCCL error: {error_str}")

    def ncclGetVersion(self) -> str:
        version = ctypes.c_int()
        self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
        version_str = str(version.value)
        # something like 21903 --> "2.19.3"
        major = version_str[0].lstrip("0")
        minor = version_str[1:3].lstrip("0")
        patch = version_str[3:].lstrip("0")
        return f"{major}.{minor}.{patch}"

    def ncclGetUniqueId(self) -> ncclUniqueId:
        unique_id = ncclUniqueId()
        self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
            ctypes.byref(unique_id)))
        return unique_id

    def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId:
        if len(data) != 128:
            raise ValueError(
                f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes")
        unique_id = ncclUniqueId()
        ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128)
        return unique_id

    def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
                         rank: int) -> ncclComm_t:
        comm = ncclComm_t()
        self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
                                                        world_size, unique_id,
                                                        rank))
        return comm

    def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, op: int, comm: ncclComm_t,
                      stream: cudaStream_t) -> None:
        # `datatype` actually should be `ncclDataType_t`
        # and `op` should be `ncclRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass an int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
                                                     datatype, op, comm,
                                                     stream))

    def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
                          count: int, datatype: int, op: int, comm: ncclComm_t,
                          stream: cudaStream_t) -> None:
        # `datatype` actually should be `ncclDataType_t`
        # and `op` should be `ncclRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass an int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff,
                                                         count, datatype, op,
                                                         comm, stream))

    def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, comm: ncclComm_t,
                      stream: cudaStream_t) -> None:
        # `datatype` actually should be `ncclDataType_t`
        # which is an alias of `ctypes.c_int`
        # when we pass an int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count,
                                                     datatype, comm, stream))

    def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
                 dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
                                                dest, comm, stream))

    def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
                 src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
                                                comm, stream))

    def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, root: int, comm: ncclComm_t,
                      stream: cudaStream_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count,
                                                     datatype, root, comm,
                                                     stream))

    def ncclCommDestroy(self, comm: ncclComm_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
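
The wrapper is deliberately thin: callers hand over raw device pointers, element counts and the integer enum values defined further below, and every call is checked through NCCL_CHECK. A hedged sketch of the typical per-rank flow follows; the rank/world_size values and the exchange of the unique id's 128 raw bytes are placeholders for whatever launcher and side channel the application actually uses.

import torch
from vllm.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclDataTypeEnum, ncclRedOpTypeEnum)

rank, world_size = 0, 2                  # hypothetical: provided by the launcher
torch.cuda.set_device(rank)

lib = NCCLLibrary()                      # resolves libnccl via find_nccl_library()
if rank == 0:
    unique_id = lib.ncclGetUniqueId()
    # the 128 raw bytes of unique_id.internal must reach all other ranks through
    # some side channel (e.g. a CPU process group); omitted here
else:
    unique_id = lib.unique_id_from_bytes(received_bytes)  # hypothetical bytes from rank 0

comm = lib.ncclCommInitRank(world_size, unique_id, rank)

tensor = torch.ones(1024, dtype=torch.float32, device=f"cuda:{rank}")
stream = torch.cuda.current_stream()
lib.ncclAllReduce(
    buffer_type(tensor.data_ptr()), buffer_type(tensor.data_ptr()),  # in place
    tensor.numel(), ncclDataTypeEnum.from_torch(tensor.dtype),
    ncclRedOpTypeEnum.ncclSum, comm, cudaStream_t(stream.cuda_stream))
torch.cuda.synchronize()                 # the collective is asynchronous on the stream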

_funcs instance-attribute

_funcs = path_to_dict_mapping[so_file]

exported_functions class-attribute instance-attribute

lib instance-attribute

lib = path_to_library_cache[so_file]

path_to_dict_mapping class-attribute instance-attribute

path_to_dict_mapping: dict[str, dict[str, Any]] = {}

path_to_library_cache class-attribute instance-attribute

path_to_library_cache: dict[str, Any] = {}

NCCL_CHECK

NCCL_CHECK(result: ncclResult_t) -> None
Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
def NCCL_CHECK(self, result: ncclResult_t) -> None:
    if result != 0:
        error_str = self.ncclGetErrorString(result)
        raise RuntimeError(f"NCCL error: {error_str}")

__init__

__init__(so_file: Optional[str] = None)
Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
def __init__(self, so_file: Optional[str] = None):

    so_file = so_file or find_nccl_library()

    try:
        if so_file not in NCCLLibrary.path_to_dict_mapping:
            lib = ctypes.CDLL(so_file)
            NCCLLibrary.path_to_library_cache[so_file] = lib
        self.lib = NCCLLibrary.path_to_library_cache[so_file]
    except Exception as e:
        logger.error(
            "Failed to load NCCL library from %s. "
            "This is expected if you are not running on NVIDIA/AMD GPUs. "
            "Otherwise, the nccl library might not exist, be corrupted "
            "or it does not support the current platform %s. "
            "If you already have the library, please set the "
            "environment variable VLLM_NCCL_SO_PATH"
            " to point to the correct nccl library path.", so_file,
            platform.platform())
        raise e

    if so_file not in NCCLLibrary.path_to_dict_mapping:
        _funcs: dict[str, Any] = {}
        for func in NCCLLibrary.exported_functions:
            f = getattr(self.lib, func.name)
            f.restype = func.restype
            f.argtypes = func.argtypes
            _funcs[func.name] = f
        NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
    self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
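
Both the ctypes.CDLL handle and the resolved function table are cached per library path at class level, so constructing NCCLLibrary repeatedly for the same file reuses the same handles. When so_file is omitted, find_nccl_library() picks the path, honoring the VLLM_NCCL_SO_PATH environment variable mentioned in the error message. A small sketch with a hypothetical library path:

so_path = "/usr/lib/x86_64-linux-gnu/libnccl.so.2"   # hypothetical location
lib_a = NCCLLibrary(so_path)
lib_b = NCCLLibrary(so_path)
assert lib_a.lib is lib_b.lib            # same ctypes.CDLL object
assert lib_a._funcs is lib_b._funcs      # same resolved-function dictionary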

ncclAllGather

ncclAllGather(
    sendbuff: buffer_type,
    recvbuff: buffer_type,
    count: int,
    datatype: int,
    comm: ncclComm_t,
    stream: cudaStream_t,
) -> None
Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
                  count: int, datatype: int, comm: ncclComm_t,
                  stream: cudaStream_t) -> None:
    # `datatype` actually should be `ncclDataType_t`
    # which is an alias of `ctypes.c_int`
    # when we pass an int to a function, it will be converted to `ctypes.c_int`
    # by ctypes automatically
    self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count,
                                                 datatype, comm, stream))
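
count is the number of elements contributed by each rank, so the receive buffer must hold world_size * count elements, laid out in rank order. A hedged sketch, reusing lib, comm, world_size and the imports from the NCCLLibrary example above:

send = torch.arange(4, dtype=torch.float32, device="cuda")
recv = torch.empty(world_size * send.numel(), dtype=torch.float32, device="cuda")
lib.ncclAllGather(
    buffer_type(send.data_ptr()), buffer_type(recv.data_ptr()), send.numel(),
    ncclDataTypeEnum.from_torch(send.dtype), comm,
    cudaStream_t(torch.cuda.current_stream().cuda_stream))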

ncclAllReduce

ncclAllReduce(
    sendbuff: buffer_type,
    recvbuff: buffer_type,
    count: int,
    datatype: int,
    op: int,
    comm: ncclComm_t,
    stream: cudaStream_t,
) -> None
Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                  count: int, datatype: int, op: int, comm: ncclComm_t,
                  stream: cudaStream_t) -> None:
    # `datatype` actually should be `ncclDataType_t`
    # and `op` should be `ncclRedOp_t`
    # both are aliases of `ctypes.c_int`
    # when we pass an int to a function, it will be converted to `ctypes.c_int`
    # by ctypes automatically
    self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
                                                 datatype, op, comm,
                                                 stream))

ncclBroadcast

ncclBroadcast(
    sendbuff: buffer_type,
    recvbuff: buffer_type,
    count: int,
    datatype: int,
    root: int,
    comm: ncclComm_t,
    stream: cudaStream_t,
) -> None
Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
                  count: int, datatype: int, root: int, comm: ncclComm_t,
                  stream: cudaStream_t) -> None:
    self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count,
                                                 datatype, root, comm,
                                                 stream))
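
root selects the rank whose sendbuff is copied into every rank's recvbuff; passing the same pointer for both broadcasts in place. A hedged sketch, reusing lib, comm and the imports from the NCCLLibrary example above:

t = torch.empty(1024, dtype=torch.float32, device="cuda")
lib.ncclBroadcast(
    buffer_type(t.data_ptr()), buffer_type(t.data_ptr()), t.numel(),
    ncclDataTypeEnum.from_torch(t.dtype), 0, comm,   # root rank 0
    cudaStream_t(torch.cuda.current_stream().cuda_stream))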

ncclCommDestroy

ncclCommDestroy(comm: ncclComm_t) -> None
Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
    self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))

ncclCommInitRank

ncclCommInitRank(
    world_size: int, unique_id: ncclUniqueId, rank: int
) -> ncclComm_t
Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
                     rank: int) -> ncclComm_t:
    comm = ncclComm_t()
    self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
                                                    world_size, unique_id,
                                                    rank))
    return comm

ncclGetErrorString

ncclGetErrorString(result: ncclResult_t) -> str
Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
def ncclGetErrorString(self, result: ncclResult_t) -> str:
    return self._funcs["ncclGetErrorString"](result).decode("utf-8")

ncclGetUniqueId

ncclGetUniqueId() -> ncclUniqueId
Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
def ncclGetUniqueId(self) -> ncclUniqueId:
    unique_id = ncclUniqueId()
    self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
        ctypes.byref(unique_id)))
    return unique_id

ncclGetVersion

ncclGetVersion() -> str
Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
def ncclGetVersion(self) -> str:
    version = ctypes.c_int()
    self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
    version_str = str(version.value)
    # something like 21903 --> "2.19.3"
    major = version_str[0].lstrip("0")
    minor = version_str[1:3].lstrip("0")
    patch = version_str[3:].lstrip("0")
    return f"{major}.{minor}.{patch}"
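
NCCL reports its version as a single integer whose digits are sliced into major/minor/patch, so for example 21903 becomes "2.19.3":

version_str = "21903"                 # what NCCL 2.19.3 reports
major = version_str[0].lstrip("0")    # "2"
minor = version_str[1:3].lstrip("0")  # "19"
patch = version_str[3:].lstrip("0")   # "03" -> "3"
assert f"{major}.{minor}.{patch}" == "2.19.3"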

ncclRecv

ncclRecv(
    recvbuff: buffer_type,
    count: int,
    datatype: int,
    src: int,
    comm: ncclComm_t,
    stream: cudaStream_t,
) -> None
Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
             src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
    self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
                                            comm, stream))

ncclReduceScatter

ncclReduceScatter(
    sendbuff: buffer_type,
    recvbuff: buffer_type,
    count: int,
    datatype: int,
    op: int,
    comm: ncclComm_t,
    stream: cudaStream_t,
) -> None
Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, op: int, comm: ncclComm_t,
                      stream: cudaStream_t) -> None:
    # `datatype` actually should be `ncclDataType_t`
    # and `op` should be `ncclRedOp_t`
    # both are aliases of `ctypes.c_int`
    # when we pass an int to a function, it will be converted to `ctypes.c_int`
    # by ctypes automatically
    self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff,
                                                     count, datatype, op,
                                                     comm, stream))
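
Here count is the per-rank output length: the send buffer holds world_size * count elements, they are reduced element-wise across ranks, and each rank receives its own count-element slice. A hedged sketch, reusing lib, comm, world_size and the imports from the NCCLLibrary example above:

send = torch.ones(world_size * 4, dtype=torch.float32, device="cuda")
recv = torch.empty(4, dtype=torch.float32, device="cuda")
lib.ncclReduceScatter(
    buffer_type(send.data_ptr()), buffer_type(recv.data_ptr()), recv.numel(),
    ncclDataTypeEnum.from_torch(send.dtype), ncclRedOpTypeEnum.ncclSum, comm,
    cudaStream_t(torch.cuda.current_stream().cuda_stream))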

ncclSend

ncclSend(
    sendbuff: buffer_type,
    count: int,
    datatype: int,
    dest: int,
    comm: ncclComm_t,
    stream: cudaStream_t,
) -> None
Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
             dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
    self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
                                            dest, comm, stream))
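
ncclSend and ncclRecv are point-to-point calls that must be issued in matching pairs with the same count and datatype; dest and src are ranks within the communicator. A hedged sketch, reusing lib, comm, rank and the imports from the NCCLLibrary example above:

t = torch.ones(256, dtype=torch.float32, device="cuda")
stream = cudaStream_t(torch.cuda.current_stream().cuda_stream)
dtype = ncclDataTypeEnum.from_torch(t.dtype)
if rank == 0:
    lib.ncclSend(buffer_type(t.data_ptr()), t.numel(), dtype, 1, comm, stream)
else:
    lib.ncclRecv(buffer_type(t.data_ptr()), t.numel(), dtype, 0, comm, stream)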

unique_id_from_bytes

unique_id_from_bytes(data: bytes) -> ncclUniqueId
Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId:
    if len(data) != 128:
        raise ValueError(
            f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes")
    unique_id = ncclUniqueId()
    ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128)
    return unique_id
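
Because the unique id is an opaque 128-byte blob, it can be serialized by taking the raw bytes of its internal field, shipped over any byte channel, and reconstructed on the receiving rank:

lib = NCCLLibrary()
uid = lib.ncclGetUniqueId()
raw = bytes(uid.internal)             # exactly 128 bytes
restored = lib.unique_id_from_bytes(raw)
assert bytes(restored.internal) == raw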

ncclDataTypeEnum

Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
class ncclDataTypeEnum:
    ncclInt8 = 0
    ncclChar = 0
    ncclUint8 = 1
    ncclInt32 = 2
    ncclInt = 2
    ncclUint32 = 3
    ncclInt64 = 4
    ncclUint64 = 5
    ncclFloat16 = 6
    ncclHalf = 6
    ncclFloat32 = 7
    ncclFloat = 7
    ncclFloat64 = 8
    ncclDouble = 8
    ncclBfloat16 = 9
    ncclNumTypes = 10

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        if dtype == torch.int8:
            return cls.ncclInt8
        if dtype == torch.uint8:
            return cls.ncclUint8
        if dtype == torch.int32:
            return cls.ncclInt32
        if dtype == torch.int64:
            return cls.ncclInt64
        if dtype == torch.float16:
            return cls.ncclFloat16
        if dtype == torch.float32:
            return cls.ncclFloat32
        if dtype == torch.float64:
            return cls.ncclFloat64
        if dtype == torch.bfloat16:
            return cls.ncclBfloat16
        raise ValueError(f"Unsupported dtype: {dtype}")

ncclBfloat16 class-attribute instance-attribute

ncclBfloat16 = 9

ncclChar class-attribute instance-attribute

ncclChar = 0

ncclDouble class-attribute instance-attribute

ncclDouble = 8

ncclFloat class-attribute instance-attribute

ncclFloat = 7

ncclFloat16 class-attribute instance-attribute

ncclFloat16 = 6

ncclFloat32 class-attribute instance-attribute

ncclFloat32 = 7

ncclFloat64 class-attribute instance-attribute

ncclFloat64 = 8

ncclHalf class-attribute instance-attribute

ncclHalf = 6

ncclInt class-attribute instance-attribute

ncclInt = 2

ncclInt32 class-attribute instance-attribute

ncclInt32 = 2

ncclInt64 class-attribute instance-attribute

ncclInt64 = 4

ncclInt8 class-attribute instance-attribute

ncclInt8 = 0

ncclNumTypes class-attribute instance-attribute

ncclNumTypes = 10

ncclUint32 class-attribute instance-attribute

ncclUint32 = 3

ncclUint64 class-attribute instance-attribute

ncclUint64 = 5

ncclUint8 class-attribute instance-attribute

ncclUint8 = 1

from_torch classmethod

from_torch(dtype: dtype) -> int
Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
    if dtype == torch.int8:
        return cls.ncclInt8
    if dtype == torch.uint8:
        return cls.ncclUint8
    if dtype == torch.int32:
        return cls.ncclInt32
    if dtype == torch.int64:
        return cls.ncclInt64
    if dtype == torch.float16:
        return cls.ncclFloat16
    if dtype == torch.float32:
        return cls.ncclFloat32
    if dtype == torch.float64:
        return cls.ncclFloat64
    if dtype == torch.bfloat16:
        return cls.ncclBfloat16
    raise ValueError(f"Unsupported dtype: {dtype}")
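
from_torch maps a torch.dtype to the raw integer NCCL expects, for example:

import torch

assert ncclDataTypeEnum.from_torch(torch.float16) == ncclDataTypeEnum.ncclFloat16    # 6
assert ncclDataTypeEnum.from_torch(torch.bfloat16) == ncclDataTypeEnum.ncclBfloat16  # 9
# unsupported dtypes such as torch.int16 raise ValueError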

ncclRedOpTypeEnum

Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
class ncclRedOpTypeEnum:
    ncclSum = 0
    ncclProd = 1
    ncclMax = 2
    ncclMin = 3
    ncclAvg = 4
    ncclNumOps = 5

    @classmethod
    def from_torch(cls, op: ReduceOp) -> int:
        if op == ReduceOp.SUM:
            return cls.ncclSum
        if op == ReduceOp.PRODUCT:
            return cls.ncclProd
        if op == ReduceOp.MAX:
            return cls.ncclMax
        if op == ReduceOp.MIN:
            return cls.ncclMin
        if op == ReduceOp.AVG:
            return cls.ncclAvg
        raise ValueError(f"Unsupported op: {op}")

ncclAvg class-attribute instance-attribute

ncclAvg = 4

ncclMax class-attribute instance-attribute

ncclMax = 2

ncclMin class-attribute instance-attribute

ncclMin = 3

ncclNumOps class-attribute instance-attribute

ncclNumOps = 5

ncclProd class-attribute instance-attribute

ncclProd = 1

ncclSum class-attribute instance-attribute

ncclSum = 0

from_torch classmethod

from_torch(op: ReduceOp) -> int
Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
@classmethod
def from_torch(cls, op: ReduceOp) -> int:
    if op == ReduceOp.SUM:
        return cls.ncclSum
    if op == ReduceOp.PRODUCT:
        return cls.ncclProd
    if op == ReduceOp.MAX:
        return cls.ncclMax
    if op == ReduceOp.MIN:
        return cls.ncclMin
    if op == ReduceOp.AVG:
        return cls.ncclAvg
    raise ValueError(f"Unsupported op: {op}")
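
Similarly, from_torch translates a torch.distributed.ReduceOp into the corresponding NCCL constant:

from torch.distributed import ReduceOp

assert ncclRedOpTypeEnum.from_torch(ReduceOp.SUM) == ncclRedOpTypeEnum.ncclSum  # 0
assert ncclRedOpTypeEnum.from_torch(ReduceOp.AVG) == ncclRedOpTypeEnum.ncclAvg  # 4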

ncclUniqueId

Bases: Structure

Source code in vllm/distributed/device_communicators/pynccl_wrapper.py
class ncclUniqueId(ctypes.Structure):
    _fields_ = [("internal", ctypes.c_byte * 128)]

_fields_ class-attribute instance-attribute

_fields_ = [('internal', c_byte * 128)]