vllm.model_executor.layers.fused_moe.pplx_prepare_finalize

PplxPrepareAndFinalize

Bases: FusedMoEPrepareAndFinalize

Source code in vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

    def __init__(
        self,
        a2a: pplx.AllToAll,
        max_num_tokens: int,
        num_local_experts: int,
        num_dispatchers: int,
    ):
        super().__init__()
        assert max_num_tokens > 0
        assert num_local_experts > 0
        self.a2a = a2a
        self.max_num_tokens = max_num_tokens
        self.num_local_experts = num_local_experts
        self.num_dispatchers_ = num_dispatchers

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.BatchedExperts

    def max_num_tokens_per_rank(self) -> Optional[int]:
        return self.max_num_tokens

    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        return torch.uint32

    def num_dispatchers(self) -> int:
        return self.num_dispatchers_

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
               Optional[torch.Tensor], Optional[torch.Tensor]]:
        num_tokens = a1.size(0)  # M
        hidden_dim = a1.size(-1)  # K

        assert topk_ids.size(0) == num_tokens
        # assert expert_map is None, "NYI"

        # Is this always going to be a1.device?
        device = a1.device

        if apply_router_weight_on_input:
            topk = topk_ids.size(1)
            # TODO: this only works for topK=1, will need to update for topK>1
            assert topk == 1, (
                "apply_router_weight_on_input is only implemented for topk=1")
            a1 = a1 * topk_weights.to(a1.dtype)

        repeat_cols = 4
        repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
        a1q, a1q_scale = moe_kernel_quantize_input(
            a1, (None if quant_config.per_act_token_quant else a1_scale),
            quant_dtype=quant_config.quant_dtype,
            per_act_token_quant=quant_config.per_act_token_quant,
            block_shape=quant_config.block_shape)

        _validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant,
                              quant_config.block_shape)

        if a1q_scale is not None:
            scalar_scales = a1q_scale.numel() == 1

            # pplx requires 2-d scales even for scalar scales
            if a1q_scale.dim() <= 1:
                assert scalar_scales
                a1q_scale = a1q_scale.view(1, 1)

            orig_a_scale_block_shape = a1q_scale.shape[-1]

            if not quant_config.is_block_quantized:
                # TODO (bnell): use group_broadcast instead?
                a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)

        assert a1q_scale is None or a1q_scale.ndim == 2, \
            f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}"

        expert_num_tokens = torch.empty(
            self.num_local_experts,
            dtype=torch.int32,
            device=device,
        )

        expert_x = torch.empty(
            (self.num_local_experts,
             self.max_num_tokens * self.num_dispatchers(), hidden_dim),
            dtype=a1q.dtype,
            device=device,
        )

        expert_x_scale: Optional[torch.Tensor] = None
        if a1q.dtype.itemsize == 1:
            if quant_config.is_per_act_token:
                # (M x 1) -> (E x M x K)
                final_dim = expert_x.size(2)
            elif quant_config.is_per_tensor:
                # (1 x 1) -> (E x 1 x 1)
                final_dim = 1
            else:
                # (M x K_tiles) -> (E x M x K_tiles)
                assert quant_config.block_shape is not None
                num_blocks = cdiv(expert_x.size(2),
                                  quant_config.block_shape[1])
                final_dim = num_blocks

            expert_x_scale_shape = (
                self.num_local_experts,
                expert_x.size(1),
                round_up(final_dim, 4)  # round up for alignment
            )

            expert_x_scale = torch.empty(
                expert_x_scale_shape,
                dtype=torch.float32,
                device=expert_x.device,
            )

        # This argument is optional, defaults to indices.size(0)
        # There's not much point setting this unless it is != indices.size(0)
        bound_m: Optional[torch.Tensor] = None

        self.a2a.dispatch(
            out_expert_num_tokens=expert_num_tokens,
            out_expert_x=expert_x,
            out_expert_x_scale=expert_x_scale,
            dp_x=a1q,
            dp_x_scale=a1q_scale,
            indices=topk_ids,
            bound_m=bound_m,
        )

        if expert_x_scale is not None:
            expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
            assert expert_x_scale.ndim == 3

        return expert_x, expert_x_scale, expert_num_tokens, None, None

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
    ) -> None:
        # This argument is optional
        # There's not much point setting this unless it is != topk_ids.size(0)
        bound_m: Optional[torch.Tensor] = None

        # TODO (bnell): fails in test_pplx_moe.py, figure out what's going on
        #num_tokens = output.size(0)  # M
        #assert topk_ids.size(0) == num_tokens, (
        #    f"{topk_ids.size(0)} == {num_tokens}")
        assert topk_ids.size() == topk_weights.size(), (
            f"{topk_ids.size()} == {topk_weights.size()}")
        assert output.size(0) <= self.max_num_tokens, (
            f"{output.size(0)} <= {self.max_num_tokens}")
        assert output.size(1) == fused_expert_output.size(-1)

        # Set weights to 1 if we did them in dispatch. This is hacky.
        if apply_router_weight_on_input:
            topk_weights = torch.ones_like(topk_weights)

        self.a2a.combine(out_tokens=output,
                         indices=topk_ids,
                         weights=topk_weights,
                         expert_y=fused_expert_output,
                         bound_m=bound_m)
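
A minimal construction sketch (not from the source): it assumes an already-initialized pplx.AllToAll handle `a2a`, uses illustrative size parameters, and imports the modular-kernel module as `mk`, as the source does.

# Hypothetical wiring sketch; `a2a` is assumed to be a pre-built
# pplx.AllToAll handle and all sizes below are illustrative only.
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
    PplxPrepareAndFinalize)

prepare_finalize = PplxPrepareAndFinalize(
    a2a=a2a,                 # assumed existing pplx.AllToAll handle
    max_num_tokens=256,      # per-rank token budget, must be > 0
    num_local_experts=8,     # experts owned by this rank, must be > 0
    num_dispatchers=4,       # number of dispatching ranks
)

assert (prepare_finalize.activation_format
        == mk.FusedMoEActivationFormat.BatchedExperts)
assert prepare_finalize.max_num_tokens_per_rank() == 256
assert prepare_finalize.topk_indices_dtype() == torch.uint32
assert prepare_finalize.num_dispatchers() == 4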

a2a instance-attribute

a2a = a2a

activation_format property

activation_format: FusedMoEActivationFormat

max_num_tokens instance-attribute

max_num_tokens = max_num_tokens

num_dispatchers_ instance-attribute

num_dispatchers_ = num_dispatchers

num_local_experts instance-attribute

num_local_experts = num_local_experts

__init__

__init__(
    a2a: AllToAll,
    max_num_tokens: int,
    num_local_experts: int,
    num_dispatchers: int,
)
Source code in vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
def __init__(
    self,
    a2a: pplx.AllToAll,
    max_num_tokens: int,
    num_local_experts: int,
    num_dispatchers: int,
):
    super().__init__()
    assert max_num_tokens > 0
    assert num_local_experts > 0
    self.a2a = a2a
    self.max_num_tokens = max_num_tokens
    self.num_local_experts = num_local_experts
    self.num_dispatchers_ = num_dispatchers

finalize

finalize(
    output: Tensor,
    fused_expert_output: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    apply_router_weight_on_input: bool,
) -> None
Source code in vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
def finalize(
    self,
    output: torch.Tensor,
    fused_expert_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    apply_router_weight_on_input: bool,
) -> None:
    # This argument is optional
    # There's not much point setting this unless it is != topk_ids.size(0)
    bound_m: Optional[torch.Tensor] = None

    # TODO (bnell): fails in test_pplx_moe.py, figure out what's going on
    #num_tokens = output.size(0)  # M
    #assert topk_ids.size(0) == num_tokens, (
    #    f"{topk_ids.size(0)} == {num_tokens}")
    assert topk_ids.size() == topk_weights.size(), (
        f"{topk_ids.size()} == {topk_weights.size()}")
    assert output.size(0) <= self.max_num_tokens, (
        f"{output.size(0)} <= {self.max_num_tokens}")
    assert output.size(1) == fused_expert_output.size(-1)

    # Set weights to 1 if we did them in dispatch. This is hacky.
    if apply_router_weight_on_input:
        topk_weights = torch.ones_like(topk_weights)

    self.a2a.combine(out_tokens=output,
                     indices=topk_ids,
                     weights=topk_weights,
                     expert_y=fused_expert_output,
                     bound_m=bound_m)
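
The asserts above pin down the shape contract between the buffers handed to combine. A hedged restatement with illustrative sizes (the fused_expert_output layout is inferred from the BatchedExperts buffers produced by prepare, not stated explicitly in this file):

# Shape contract implied by the asserts in finalize() (illustrative sizes):
#   output:                 (M, K), with M <= max_num_tokens
#   topk_ids, topk_weights: both (M, topk), same shape as each other
#   fused_expert_output:    last dim == K; with the BatchedExperts layout
#                           from prepare() it is typically
#                           (num_local_experts,
#                            max_num_tokens * num_dispatchers, K)
M, K, topk, E = 64, 4096, 2, 8
max_num_tokens, num_dispatchers = 256, 4
assert M <= max_num_tokens
fused_expert_output_shape = (E, max_num_tokens * num_dispatchers, K)
assert fused_expert_output_shape[-1] == K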

max_num_tokens_per_rank

max_num_tokens_per_rank() -> Optional[int]
Source code in vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
def max_num_tokens_per_rank(self) -> Optional[int]:
    return self.max_num_tokens

num_dispatchers

num_dispatchers() -> int
Source code in vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
def num_dispatchers(self) -> int:
    return self.num_dispatchers_

prepare

prepare(
    a1: Tensor,
    a1_scale: Optional[Tensor],
    a2_scale: Optional[Tensor],
    topk_weights: Tensor,
    topk_ids: Tensor,
    num_experts: int,
    expert_map: Optional[Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> tuple[
    Tensor,
    Optional[Tensor],
    Optional[Tensor],
    Optional[Tensor],
    Optional[Tensor],
]
Source code in vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
def prepare(
    self,
    a1: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
           Optional[torch.Tensor], Optional[torch.Tensor]]:
    num_tokens = a1.size(0)  # M
    hidden_dim = a1.size(-1)  # K

    assert topk_ids.size(0) == num_tokens
    # assert expert_map is None, "NYI"

    # Is this always going to be a1.device?
    device = a1.device

    if apply_router_weight_on_input:
        topk = topk_ids.size(1)
        # TODO: this only works for topK=1, will need to update for topK>1
        assert topk == 1, (
            "apply_router_weight_on_input is only implemented for topk=1")
        a1 = a1 * topk_weights.to(a1.dtype)

    repeat_cols = 4
    repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
    a1q, a1q_scale = moe_kernel_quantize_input(
        a1, (None if quant_config.per_act_token_quant else a1_scale),
        quant_dtype=quant_config.quant_dtype,
        per_act_token_quant=quant_config.per_act_token_quant,
        block_shape=quant_config.block_shape)

    _validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant,
                          quant_config.block_shape)

    if a1q_scale is not None:
        scalar_scales = a1q_scale.numel() == 1

        # pplx requires 2-d scales even for scalar scales
        if a1q_scale.dim() <= 1:
            assert scalar_scales
            a1q_scale = a1q_scale.view(1, 1)

        orig_a_scale_block_shape = a1q_scale.shape[-1]

        if not quant_config.is_block_quantized:
            # TODO (bnell): use group_broadcast instead?
            a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)

    assert a1q_scale is None or a1q_scale.ndim == 2, \
        f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}"

    expert_num_tokens = torch.empty(
        self.num_local_experts,
        dtype=torch.int32,
        device=device,
    )

    expert_x = torch.empty(
        (self.num_local_experts,
         self.max_num_tokens * self.num_dispatchers(), hidden_dim),
        dtype=a1q.dtype,
        device=device,
    )

    expert_x_scale: Optional[torch.Tensor] = None
    if a1q.dtype.itemsize == 1:
        if quant_config.is_per_act_token:
            # (M x 1) -> (E x M x K)
            final_dim = expert_x.size(2)
        elif quant_config.is_per_tensor:
            # (1 x 1) -> (E x 1 x 1)
            final_dim = 1
        else:
            # (M x K_tiles) -> (E x M x K_tiles)
            assert quant_config.block_shape is not None
            num_blocks = cdiv(expert_x.size(2),
                              quant_config.block_shape[1])
            final_dim = num_blocks

        expert_x_scale_shape = (
            self.num_local_experts,
            expert_x.size(1),
            round_up(final_dim, 4)  # round up for alignment
        )

        expert_x_scale = torch.empty(
            expert_x_scale_shape,
            dtype=torch.float32,
            device=expert_x.device,
        )

    # This argument is optional, defaults to indices.size(0)
    # There's not much point setting this unless it is != indices.size(0)
    bound_m: Optional[torch.Tensor] = None

    self.a2a.dispatch(
        out_expert_num_tokens=expert_num_tokens,
        out_expert_x=expert_x,
        out_expert_x_scale=expert_x_scale,
        dp_x=a1q,
        dp_x_scale=a1q_scale,
        indices=topk_ids,
        bound_m=bound_m,
    )

    if expert_x_scale is not None:
        expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
        assert expert_x_scale.ndim == 3

    return expert_x, expert_x_scale, expert_num_tokens, None, None
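
A hedged trace of the scale bookkeeping above for the per-tensor fp8 case, using illustrative sizes (num_local_experts=8, max_num_tokens=256, num_dispatchers=4, hidden_dim=4096, M=64 tokens on this rank); values are examples, not from the source:

E, max_num_tokens, num_dispatchers, K = 8, 256, 4, 4096
M = 64  # tokens on this rank

# expert_x is batched per local expert:
expert_x_shape = (E, max_num_tokens * num_dispatchers, K)        # (8, 1024, 4096)

# per-tensor quant: final_dim = 1, padded up to a multiple of 4 for alignment
expert_x_scale_shape = (E, max_num_tokens * num_dispatchers, 4)  # round_up(1, 4) == 4

# a1q_scale: scalar -> view(1, 1) -> repeat(M, 4) == (64, 4);
# orig_a_scale_block_shape == 1, so after dispatch the scale buffer is
# trimmed back by expert_x_scale[:, :, :1] to shape (8, 1024, 1).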

topk_indices_dtype

topk_indices_dtype() -> Optional[dtype]
Source code in vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
def topk_indices_dtype(self) -> Optional[torch.dtype]:
    return torch.uint32

pplx_hidden_dim_scale_bytes

pplx_hidden_dim_scale_bytes(
    max_num_tokens: int,
    hidden_dim: int,
    in_dtype: dtype,
    quant_dtype: Optional[dtype],
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
)
Source code in vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
def pplx_hidden_dim_scale_bytes(
    max_num_tokens: int,
    hidden_dim: int,
    in_dtype: torch.dtype,
    quant_dtype: Optional[torch.dtype],
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
):
    # All pplx byte sizes must be 16-byte aligned.
    align = 16

    # For blocked per token: set to
    #   ceil_div(hidden_dim, block_size) * sizeof(float32)
    # For per-token: set to 4 * sizeof(float32) (x4 for alignment)
    if quant_dtype is not None:
        assert quant_dtype.itemsize == 1
        hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
        elem_size = torch.float32.itemsize

        if per_act_token_quant:
            # per-token (M x 1)
            assert block_shape is None
            hidden_scale_bytes = elem_size
        elif block_shape is not None:
            # per-group (M x K_tiles)
            block_size = block_shape[1]
            num_blocks = cdiv(hidden_dim, block_size)
            hidden_scale_bytes = num_blocks * elem_size
        else:
            # per-tensor (1 x 1)
            hidden_scale_bytes = elem_size
    else:
        hidden_dim_bytes = hidden_dim * in_dtype.itemsize
        hidden_scale_bytes = 0

    return (
        round_up(hidden_dim_bytes, align),
        round_up(hidden_scale_bytes, align),
    )
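
A hedged worked example of the two return values, using illustrative arguments (fp8 activations, hidden_dim=4096, 128-wide quantization groups); the specific values are examples, not from the source:

hidden_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
    max_num_tokens=256,
    hidden_dim=4096,
    in_dtype=torch.bfloat16,
    quant_dtype=torch.float8_e4m3fn,
    per_act_token_quant=False,
    block_shape=[128, 128],
)
# hidden_dim_bytes   = 4096 * 1 byte                 -> round_up(4096, 16) == 4096
# hidden_scale_bytes = cdiv(4096, 128) * 4 bytes
#                    = 32 * 4 = 128                  -> round_up(128, 16)  == 128
assert (hidden_bytes, scale_bytes) == (4096, 128)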