class TritonExperts(LoRAExpertsMixin, mk.FusedMoEExpertsModular):
"""Triton-based fused MoE expert implementation."""
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
# Whether quantized MoE runs natively in the quantized dtype, or is emulated
# via a higher-precision GEMM plus activation quantize-dequantize (QDQ).
self.quantization_emulation = False
super().__init__(moe_config, quant_config)
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
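# Standard (non-batched) layout: activations are expected as a contiguous
# (num_tokens, hidden_size) tensor rather than grouped per expert.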
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_current_device() -> bool:
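# The Triton kernels are available on CUDA, ROCm (cuda-alike) and XPU devices.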
return current_platform.is_cuda_alike() or current_platform.is_xpu()
@staticmethod
def _supports_no_act_and_mul() -> bool:
return True
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
# INT8 W8A8 requires compute capability 7.5 (Turing) or newer.
device_supports_int8 = (
current_platform.is_cuda()
and current_platform.has_device_capability((7, 5))
)
supported: list[tuple[QuantKey | None, QuantKey | None]] = [(None, None)]
if device_supports_int8:
supported.append((kInt8StaticChannelSym, kInt8DynamicTokenSym))
if current_platform.supports_fp8():
supported += [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8DynamicTokenSym),
(kFp8StaticTensorSym, kFp8StaticTensorSym),
(kFp8StaticTensorSym, kFp8DynamicTensorSym),
]
return (weight_key, activation_key) in supported
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.GELU_TANH,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.GELU_TANH_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return not (
moe_parallel_config.use_fi_nvl_two_sided_kernels
or moe_parallel_config.use_fi_nvl_one_sided_kernels
)
@staticmethod
def _supports_batch_invariance() -> bool:
return True
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
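# Routing weights are applied inside apply() (second GEMM) and the top-k
# reduction is done by moe_sum(), so finalize needs no extra weight-and-reduce.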
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
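# workspace1 (passed to apply() as workspace13) backs intermediate_cache2 and
# possibly the final output; workspace2 backs intermediate_cache1/3; the
# reduced result is (M, K).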
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M, topk, max(activation_out_dim, K))
workspace2 = (M, topk, max(N, K))
output = (M, K)
return (workspace1, workspace2, output)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
# Check constraints.
if self.quant_config.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch"
else:
assert hidden_states.size(-1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}"
)
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert hidden_states.dim() == 2
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32,
torch.float16,
torch.bfloat16,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
]
E, num_tokens, N, K, top_k_num = self.moe_problem_size(
hidden_states, w1, w2, topk_ids
)
if global_num_experts == -1:
global_num_experts = E
config = try_get_optimal_moe_config(
w1.size(),
w2.size(),
top_k_num,
self.quant_config.config_name(hidden_states.dtype),
num_tokens,
block_shape=self.block_shape,
)
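# compute_type is the Triton element type used for the kernel outputs;
# fp8 activations use bf16 intermediates here.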
if hidden_states.dtype == torch.bfloat16:
compute_type = tl.bfloat16
elif hidden_states.dtype == torch.float16:
compute_type = tl.float16
elif hidden_states.dtype == torch.float32:
compute_type = tl.float32
elif (
hidden_states.dtype == torch.float8_e4m3fn
or hidden_states.dtype == torch.float8_e4m3fnuz
):
compute_type = tl.bfloat16
else:
raise ValueError(f"Unsupported hidden_states dtype: {hidden_states.dtype}")
# Note that the output tensor might be allocated from workspace13, so both
# GEMM outputs (intermediate_cache1 and intermediate_cache3) are carved out of
# workspace2; cache1 is fully consumed by the activation before cache3 is
# written, so they can safely alias.
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
cache2_dim = self.adjust_N_for_activation(N, activation)
intermediate_cache2 = _resize_cache(
workspace13, (num_tokens * top_k_num, cache2_dim)
)
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))
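# Sort and pad token indices by expert so that each kernel block only
# processes tokens routed to a single expert.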
sorted_token_ids, expert_ids, num_tokens_post_padded = (
_prepare_expert_assignment(
topk_ids,
config,
num_tokens,
top_k_num,
global_num_experts,
expert_map,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
block_shape=self.block_shape,
)
)
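# First grouped GEMM: hidden_states @ w1 per selected expert into
# intermediate_cache1. Routing weights are not multiplied in here (None /
# False); they are applied upstream when apply_router_weight_on_input is set,
# otherwise in the second GEMM below.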
invoke_fused_moe_triton_kernel(
hidden_states,
w1,
intermediate_cache1,
a1q_scale,
self.w1_scale,
None, # topk_weights
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False, # mul_routed_weights
top_k_num,
config,
compute_type=compute_type,
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a8=self.quant_config.use_int8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape,
B_bias=self.w1_bias,
)
# LoRA w13: applied to intermediate_cache1 before activation, using
# hidden_states as the lora_a input. moe_lora_align_block_size is
# called once here and its results are reused for the w2 LoRA below.
sorted_token_ids_lora = None
expert_ids_lora = None
num_tokens_post_padded_lora = None
token_lora_mapping = None
lora_context = self._lora_context
if lora_context is not None:
(
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
token_lora_mapping,
) = self.apply_w13_lora(
lora_context,
y=intermediate_cache1,
x=hidden_states,
topk_ids=topk_ids,
topk_weights=topk_weights,
expert_map=expert_map,
w1=w1,
w2=w2,
num_tokens=num_tokens,
top_k_num=top_k_num,
)
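# Activation epilogue over the flattened cache1; cache2_dim (from
# adjust_N_for_activation) accounts for whether the activation fuses the gate
# multiply (act-and-mul vs. *_NO_MUL variants).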
self.activation(
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
)
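# Re-quantize the activation output so the second GEMM consumes the same
# quantized dtype and scales as w2.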
a2q_scale: torch.Tensor | None = None
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
intermediate_cache2,
a2_scale,
self.quant_dtype,
self.per_act_token_quant,
self.block_shape,
quantization_emulation=self.quantization_emulation,
)
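# Second grouped GEMM: intermediate activations @ w2 per expert into
# intermediate_cache3, multiplying in the routing weights unless they were
# already applied on the input.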
invoke_fused_moe_triton_kernel(
qintermediate_cache2,
w2,
intermediate_cache3,
a2q_scale,
self.w2_scale,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
not apply_router_weight_on_input, # mul_routed_weights
1, # top_k_num: rows are already expanded per (token, expert)
config,
compute_type=compute_type,
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
use_int8_w8a8=self.quant_config.use_int8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape,
B_bias=self.w2_bias,
)
# LoRA w2: applied to intermediate_cache3 before moe_sum, using the
# unquantized intermediate_cache2 as the lora_a input. Reuses the
# sorted_token_ids_lora computed above.
if lora_context is not None:
self.apply_w2_lora(
lora_context,
y=intermediate_cache3,
x=intermediate_cache2,
topk_weights=topk_weights,
sorted_token_ids_lora=sorted_token_ids_lora,
expert_ids_lora=expert_ids_lora,
num_tokens_post_padded_lora=num_tokens_post_padded_lora,
token_lora_mapping=token_lora_mapping,
num_tokens=num_tokens,
w1=w1,
w2=w2,
top_k_num=top_k_num,
)
# moe_sum is kept as a separate, overridable method as required by the
# MoE + LoRA integration.
self.moe_sum(intermediate_cache3, output)
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
ops.moe_sum(input, output)