class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
    """FlashInfer prepare/finalize implementation that uses the one-sided
    NVLink MoE all-to-all kernel to dispatch and combine tokens across
    expert-parallel ranks."""

    def __init__(
self,
max_num_tokens: int,
top_k: int,
num_experts: int,
hidden_size: int,
num_dispatchers: int = 1,
):
super().__init__()
self.max_num_tokens = max_num_tokens
self.top_k = top_k
self.num_experts = num_experts
self.hidden_size = hidden_size
self.num_dispatchers_ = num_dispatchers
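        # Reuse the EP group's all-to-all manager and initialize its one-sided
        # NVLink buffers for this layer's token/expert configuration.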
self.all2all_manager = get_ep_group().device_communicator.all2all_manager
self.all2all_manager.initialize(
max_num_tokens=self.max_num_tokens,
top_k=self.top_k,
num_experts=self.num_experts,
hidden_size=self.hidden_size,
)

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.Standard

    def max_num_tokens_per_rank(self) -> int | None:
        return None

    def num_dispatchers(self) -> int:
        return self.num_dispatchers_

    def output_is_reduced(self) -> bool:
        return False

    def topk_indices_dtype(self) -> torch.dtype | None:
        return torch.int32

    def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
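        """Quantize the input activations and dispatch them, together with the
        top-k routing metadata, to the ranks that own the selected experts.

        Returns the received activations, optional scales, and routing tensors
        flattened for the local expert computation; no expert-token metadata is
        produced (that result slot is None).
        """
        # Fold the routing weight into the activations up front when the caller
        # requests it; only supported for top-1 routing.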
if apply_router_weight_on_input:
topk = topk_ids.size(1)
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
a1.mul_(topk_weights.to(a1.dtype))
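
        # Determine the largest per-rank token count for this step; dispatch
        # and combine use it to bound the per-rank receive layout. Fall back to
        # the local batch size when the global sizes are unavailable.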
global_num_tokens_cpu = get_local_sizes()
self.runtime_max_tokens_per_rank = (
max(global_num_tokens_cpu)
if global_num_tokens_cpu is not None
else a1.shape[0]
)
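
        # Quantize the activations before communication so the all-to-all moves
        # the compressed representation; FP4 scale swizzling is deferred until
        # after dispatch.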
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
is_fp4_scale_swizzled=False, # delay swizzle to after comm
)
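
        # Pack everything that must travel with each token through the
        # dispatch: quantized activations, optional scales, and the top-k
        # routing metadata.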
        payloads = [a1q]
        if a1q_scale is not None:
            payloads.append(a1q_scale)
        payloads.extend([topk_ids, topk_weights])
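
        # One-sided NVLink all-to-all: each token is sent to every rank that
        # hosts one of its selected experts.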
recv_payloads = self.all2all_manager.moe_alltoall.dispatch(
token_selected_experts=topk_ids,
input_payloads=payloads,
runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank,
)
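
        # Received payloads come back in the same order they were packed above.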
if a1q_scale is not None:
a1q_recv, a1q_scale_recv, topk_ids_recv, topk_weights_recv = recv_payloads
# Apply scale interleaving only for CUTLASS (not TRT-LLM)
if (
quant_config.quant_dtype == "nvfp4"
and quant_config.is_nvfp4_scale_swizzled
):
a1q_scale_recv = a1q_scale_recv.view(-1, a1q_scale_recv.shape[-1])
a1q_scale_recv = a1q_scale_recv.view(torch.uint8)
a1q_scale_recv = nvfp4_block_scale_interleave(a1q_scale_recv)
a1q_scale_recv = a1q_scale_recv.view(-1, self.hidden_size // 16)
else:
a1q_recv, topk_ids_recv, topk_weights_recv = recv_payloads
a1q_scale_recv = None
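
        # Flatten the padded receive layout into 2-D tensors for the
        # downstream expert kernels.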
a1q_recv = a1q_recv.view(-1, a1q_recv.shape[-1])
topk_ids_recv = topk_ids_recv.view(-1, topk_ids_recv.shape[-1])
topk_weights_recv = topk_weights_recv.view(-1, topk_weights_recv.shape[-1])
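        # No expert-token metadata is computed here, so the third result slot
        # is None.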
        return a1q_recv, a1q_scale_recv, None, topk_ids_recv, topk_weights_recv

    def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
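        """Combine the per-expert outputs back to their source ranks and write
        the result into output."""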
assert self.all2all_manager.moe_alltoall is not None
ep_size = self.all2all_manager.world_size
hidden_size = fused_expert_output.shape[-1]
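
        # Restore the padded per-rank layout that the combine kernel expects.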
fused_expert_output = fused_expert_output.view(
ep_size, self.runtime_max_tokens_per_rank, hidden_size
)
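
        # Combine returns each token's expert results to its source rank,
        # reducing contributions from every rank that processed it.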
combined_output = self.all2all_manager.moe_alltoall.combine(
payload=fused_expert_output,
runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank,
)
output.copy_(combined_output)