Skip to content

vllm.distributed.device_communicators.flashinfer_all_reduce

FlashInferAllReduce

Source code in vllm/distributed/device_communicators/flashinfer_all_reduce.py
class FlashInferAllReduce:
    def __init__(
        self,
        group: ProcessGroup,
        device: int | str | torch.device,
    ):
        self.disabled = True

        if not fi_ar_available:
            logger.info(
                "FlashInfer All Reduce is disabled because flashinfer is not available"
            )
            return

        if not current_platform.is_cuda():
            logger.info(
                "FlashInfer All Reduce is disabled because it requires CUDA platform"
            )
            return

        self.group = group
        self.world_size = dist.get_world_size(self.group)
        self.rank = dist.get_rank(self.group)
        self.device = device
        if self.world_size == 1:
            return

        # Use the same threshold as the allreduce-rms fusion pass
        # TODO: tune the threshold
        MiB = 1024 * 1024
        max_workspace_size = PassConfig.default_fi_allreduce_fusion_max_size_mb().get(
            self.world_size, None
        )
        if not max_workspace_size:
            logger.warning(
                "FlashInfer All Reduce is disabled because it "
                "is not supported for world_size=%d.",
                self.world_size,
            )
            return
        self.max_workspace_size = max_workspace_size * MiB
        self.max_num_tokens = 0
        self.disabled = False

    def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool:
        """Ensure the all reduce workspace is initialized."""
        if get_fi_ar_workspace() is not None:
            return True
        if self.max_num_tokens == 0:
            element_size = torch.tensor([], dtype=dtype, device="cpu").element_size()
            self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)
        try:
            initialize_fi_ar_workspace(
                world_size=self.world_size,
                rank=self.rank,
                max_token_num=self.max_num_tokens,
                hidden_dim=hidden_dim,
                dtype=dtype,
                group=self.group,
            )
            return True
        except Exception as e:
            logger.warning(
                "Failed to initialize FlashInfer All Reduce workspace: %s. "
                "FlashInfer All Reduce will be disabled.",
                e,
            )
            self.disabled = True
            return False

    def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool:
        if self.disabled:
            return False

        if not input_tensor.is_cuda:
            return False

        if not input_tensor.is_contiguous():
            return False

        if len(input_tensor.shape) != 2:
            return False

        num_tokens, hidden_dim = input_tensor.shape
        if not self.max_num_tokens:
            element_size = torch.tensor([], dtype=input_tensor.dtype).element_size()
            self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size)

        if num_tokens > self.max_num_tokens:
            return False

        return self._ensure_workspace(hidden_dim, input_tensor.dtype)

    def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor:
        workspace = get_fi_ar_workspace()
        return flashinfer_comm.allreduce_fusion(
            input=input_tensor,
            workspace=workspace,
            pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce,
        )

    def destroy(self):
        if not self.disabled:
            destroy_fi_ar_workspace()

_ensure_workspace

_ensure_workspace(hidden_dim: int, dtype: dtype) -> bool

Ensure the all reduce workspace is initialized.

Source code in vllm/distributed/device_communicators/flashinfer_all_reduce.py
def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool:
    """Lazily create the module-level FlashInfer all-reduce workspace.

    Returns:
        True when a workspace is available (pre-existing or created now).
        False when creation raised; the error is logged and
        ``self.disabled`` is set so the backend is no longer used.
    """
    if get_fi_ar_workspace() is not None:
        return True

    if self.max_num_tokens == 0:
        # Convert the byte budget into a row count for this hidden size/dtype.
        bytes_per_elem = torch.tensor([], dtype=dtype, device="cpu").element_size()
        bytes_per_token = hidden_dim * bytes_per_elem
        self.max_num_tokens = self.max_workspace_size // bytes_per_token

    try:
        initialize_fi_ar_workspace(
            world_size=self.world_size,
            rank=self.rank,
            max_token_num=self.max_num_tokens,
            hidden_dim=hidden_dim,
            dtype=dtype,
            group=self.group,
        )
    except Exception as err:
        # Best-effort: fall back to other all-reduce paths on any failure.
        logger.warning(
            "Failed to initialize FlashInfer All Reduce workspace: %s. "
            "FlashInfer All Reduce will be disabled.",
            err,
        )
        self.disabled = True
        return False
    return True

initialize_fi_ar_quant_workspace

initialize_fi_ar_quant_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: dtype,
    group: ProcessGroup,
) -> None

Initialize the workspace used by quantization fusion patterns.

Currently this always creates a workspace for the trtllm backend, since only that backend supports quantization fusion (FP8/FP4). If the primary workspace already uses trtllm, the quant workspace aliases it.

Source code in vllm/distributed/device_communicators/flashinfer_all_reduce.py
def initialize_fi_ar_quant_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
) -> None:
    """
    Create, at most once, the workspace used by quantization fusion patterns.

    Only the trtllm backend supports FP8/FP4 quantization fusion, so this
    workspace is always a trtllm one. When the primary workspace already
    uses trtllm, it is shared (aliased) instead of allocating a second one.
    """
    global _fi_ar_quant_workspace
    if _fi_ar_quant_workspace is not None:
        return  # already initialized

    primary = _fi_ar_workspace
    primary_is_trtllm = primary is not None and primary.backend == "trtllm"
    if primary_is_trtllm:
        # Share the existing trtllm workspace.
        _fi_ar_quant_workspace = primary
        return

    torch_comm = TorchDistBackend(group=group)
    workspace_args = {
        "backend": "trtllm",
        "world_size": world_size,
        "rank": rank,
        "max_token_num": max_token_num,
        "hidden_dim": hidden_dim,
        "dtype": dtype,
        "comm_backend": torch_comm,
    }
    _fi_ar_quant_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
        **workspace_args
    )
    assert _fi_ar_quant_workspace is not None
    logger.debug(
        "Initialized FlashInfer All Reduce workspace: backend=trtllm, "
        "world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
        world_size,
        rank,
        max_token_num,
        hidden_dim,
        dtype,
    )

initialize_fi_ar_workspace

initialize_fi_ar_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: dtype,
    group: ProcessGroup,
) -> None

Initialize the workspace if not already initialized.

Currently, this function is called by either the AllReduceFusionPass or the FlashInferAllReduce backend for standalone allreduce. If the fusion pass is enabled via --compilation-config.pass_config.fuse_allreduce_rms=true, it will create the workspace first, and the standalone backend will reuse the workspace. Otherwise, the standalone backend will create the workspace.

Source code in vllm/distributed/device_communicators/flashinfer_all_reduce.py
def initialize_fi_ar_workspace(
    world_size: int,
    rank: int,
    max_token_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    group: ProcessGroup,
) -> None:
    """
    Create the module-level FlashInfer all-reduce workspace unless one exists.

    Two callers may reach this: the AllReduceFusionPass and the standalone
    FlashInferAllReduce backend. With
    --compilation-config.pass_config.fuse_allreduce_rms=true the fusion pass
    creates the workspace first and the standalone backend reuses it;
    otherwise the standalone backend creates it.

    The backend name is read from ``envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND``.
    """
    global _fi_ar_workspace
    if _fi_ar_workspace is not None:
        return  # created earlier by another caller

    selected_backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
    torch_comm = TorchDistBackend(group=group)
    workspace_args = {
        "backend": selected_backend,
        "world_size": world_size,
        "rank": rank,
        "max_token_num": max_token_num,
        "hidden_dim": hidden_dim,
        "dtype": dtype,
        "comm_backend": torch_comm,
    }
    _fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
        **workspace_args
    )
    assert _fi_ar_workspace is not None
    logger.debug(
        "Initialized FlashInfer All Reduce workspace: backend=%s, "
        "world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s",
        selected_backend,
        world_size,
        rank,
        max_token_num,
        hidden_dim,
        dtype,
    )