class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
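    """MoE method for FP8 W8A8 checkpoints produced by compressed-tensors.

    Supports per-tensor weight/activation scales, channelwise weight scales
    with dynamic per-token activation scales, and block-wise weight scales.
    Depending on the platform and configuration, execution is dispatched to
    the Marlin weight-only kernel, the ROCm AITER fused MoE kernel, the
    CUTLASS grouped-GEMM FP8 kernels, or the Triton/DeepGEMM fused-experts
    path.
    """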
def __init__(
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations"
)
per_tensor = (
self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy == QuantizationStrategy.TENSOR
)
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN
)
if not (per_tensor or per_channel):
assert self.weight_quant.strategy == QuantizationStrategy.BLOCK
self.weight_block_size = self.weight_quant.block_structure
assert self.weight_quant.dynamic is not None
else:
self.weight_block_size = None
self.block_quant = self.weight_block_size is not None
self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales and per_channel:
            raise ValueError(
                "For FP8 fused MoE layers, we require either per-tensor "
                "weight and activation quantization, or channelwise weight "
                "quantization with dynamic per-token activation quantization. "
                "Static input scales are not supported with channelwise "
                "weights."
            )
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        ) and not self.block_quant
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
# cutlass path
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
self.weight_quant, self.input_quant
)
self.use_cutlass = not self.block_quant and (
quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant)
or self.is_fp8_w8a8_sm100
)
self.disable_expert_map = False
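        # Rough summary of the kernel-path flags resolved above:
        #   use_marlin             -> weight-only FP8 via the Marlin kernel
        #   rocm_aiter_moe_enabled -> ROCm AITER fused MoE kernel
        #   use_cutlass            -> CUTLASS grouped-GEMM FP8 (SM90/SM100,
        #                             non-block quantization only)
        #   otherwise              -> Triton fused_experts / DeepGEMM
        # disable_expert_map is updated in select_gemm_impl() when the chosen
        # CUTLASS experts implementation cannot use an expert map (or when
        # there is more than one dispatcher).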
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
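        # Weights are allocated directly in FP8 (e4m3fn):
        #   w13_weight: [num_experts, 2 * intermediate_size_per_partition,
        #                hidden_size]
        #   w2_weight:  [num_experts, hidden_size,
        #                intermediate_size_per_partition]
        # Scale shapes depend on the weight quantization strategy (see the
        # WEIGHT_SCALES section below).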
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.hidden_size = hidden_size
layer.num_experts = num_experts
layer.orig_dtype = params_dtype
layer.weight_block_size = None
params_dtype = torch.float8_e4m3fn
if self.block_quant:
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
tp_size = get_tensor_model_parallel_world_size()
block_n, block_k = (
self.weight_block_size[0],
self.weight_block_size[1],
)
# NOTE: To ensure proper alignment of the block-wise quantization
# scales, the output_size of the weights for both the gate and up
# layers must be divisible by block_n.
# Required by column parallel or enabling merged weights
if intermediate_size_per_partition % block_n != 0:
raise ValueError(
f"The output_size of gate's and up's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_n = {block_n}."
)
if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
# Required by row parallel
raise ValueError(
f"The input_size of down's weight = "
f"{intermediate_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}."
)
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
# Allocate 2 scales for w1 and w3 respectively.
# They are combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-TENSOR quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
(hidden_size + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
(hidden_size + block_n - 1) // block_n,
(intermediate_size_per_partition + block_k - 1) // block_k,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
            # Add PER-BLOCK quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
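        # Weight-scale shapes registered above (E = num_experts,
        # I = intermediate_size_per_partition, H = hidden_size):
        #   TENSOR:  w13 [E, 2] (merged to [E] after loading), w2 [E]
        #   CHANNEL: w13 [E, 2 * I, 1], w2 [E, H, 1]
        #   BLOCK:   w13 [E, 2 * ceil(I / block_n), ceil(H / block_k)],
        #            w2  [E, ceil(H / block_n), ceil(I / block_k)]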
# INPUT_SCALES
if self.static_input_scales:
w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.static_input_scales:
assert self.input_quant.strategy == QuantizationStrategy.TENSOR
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
layer.w2_input_scale
):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False
)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False
)
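            # Illustrative example: per-expert input scales
            # [0.010, 0.021, 0.015] collapse to the scalar 0.021 via .max().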
if current_platform.is_fp8_fnuz():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
)
)
w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
)
            # Reset the parameters with the normalized values.
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False
)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False
)
        # For the per-tensor case, the FP8 MoE kernel needs a single w13
        # weight scale per expert. Take the max over the w1/w3 shards, then
        # dequantize and requantize each shard with that shared scale.
if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
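            # w1 and w3 were quantized with separate scales; the loop below
            # effectively recomputes
            #     w13_fp8[shard] = quant(dequant(w13_fp8[shard], s_shard), s_max)
            # so that both shards share the per-expert max scale.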
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
)
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales, requires_grad=False
)
        # If the ROCm AITER MoE kernel is enabled, shuffle the weights into
        # the layout it expects.
if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                shuffle_weights,
            )
            # Reshuffling the weights is required by the AITER MoE kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
elif self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
if self.use_cutlass:
assert self.weight_quant.strategy != QuantizationStrategy.BLOCK
device = layer.w13_weight.device
# ab_strides1 and c_strides2 are the same
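            # Per-expert stride tensors for the two CUTLASS grouped GEMMs:
            # the first GEMM uses input stride hidden_size (ab_strides1) and
            # output stride 2 * intermediate_size (c_strides1); the second
            # uses input stride intermediate_size (ab_strides2) and output
            # stride hidden_size (c_strides2, shared with ab_strides1).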
self.ab_strides1_c_strides2 = torch.full(
(layer.local_num_experts,),
layer.hidden_size,
device=device,
dtype=torch.int64,
)
self.ab_strides2 = torch.full(
(layer.local_num_experts,),
layer.intermediate_size_per_partition,
device=device,
dtype=torch.int64,
)
self.c_strides1 = torch.full(
(layer.local_num_experts,),
2 * layer.intermediate_size_per_partition,
device=device,
dtype=torch.int64,
)
if is_deep_gemm_e8m0_used() and self.block_quant:
assert layer.weight_block_size is not None
            # Re-quantize the expert weights so their scales are UE8M0.
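            # UE8M0 scales are unsigned, exponent-only (power-of-two) values;
            # requantizing in place keeps the FP8 weights consistent with
            # these coarser scales.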
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
layer.w13_weight.data,
layer.w13_weight_scale.data,
block_sz,
)
requant_weight_ue8m0_inplace(
layer.w2_weight.data,
layer.w2_weight_scale.data,
block_sz,
)
# Ensure column-major TMA alignment expected by DeepGEMM.
if expert_weight_is_col_major(layer.w13_weight_scale):
layer.w13_weight_scale = get_col_major_tma_aligned_tensor(
layer.w13_weight_scale
)
if expert_weight_is_col_major(layer.w2_weight_scale):
layer.w2_weight_scale = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale
)
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
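        # The Marlin and ROCm AITER paths invoke their own fused kernels
        # directly in apply(), so no modular prepare/finalize object is
        # created for them.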
if self.use_marlin or self.rocm_aiter_moe_enabled:
return None
else:
return super().maybe_make_prepare_finalize()
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path
assert self.moe_quant_config is not None
if self.use_cutlass:
from vllm.model_executor.layers.fused_moe import (
CutlassBatchedExpertsFp8,
CutlassExpertsFp8,
)
experts: FusedMoEPermuteExpertsUnpermute
num_dispatchers = prepare_finalize.num_dispatchers()
if (
prepare_finalize.activation_format
== FusedMoEActivationFormat.BatchedExperts
):
logger.debug("CutlassBatchedExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassBatchedExpertsFp8(
self.moe.num_local_experts,
num_dispatchers,
self.moe.in_dtype,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
quant_config=self.moe_quant_config,
)
else:
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassExpertsFp8(
self.moe.in_dtype,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
quant_config=self.moe_quant_config,
)
self.disable_expert_map = (
num_dispatchers > 1 or not experts.supports_expert_map()
)
return experts
# triton path
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
assert not self.rocm_aiter_moe_enabled and not self.use_marlin
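        # The Marlin and AITER paths never reach this point because
        # maybe_make_prepare_finalize() returns None for them.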
if (
prepare_finalize.activation_format
== FusedMoEActivationFormat.BatchedExperts
):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
return BatchedTritonOrDeepGemmExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
)
else:
logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
return TritonOrDeepGemmExperts(self.moe_quant_config, allow_deep_gemm=True)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
if self.use_marlin:
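            # The Marlin path passes its weights and scales directly to
            # fused_marlin_moe() in apply() and does not consume a
            # FusedMoEQuantConfig.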
return None
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_channel_quant,
block_shape=layer.weight_block_size,
)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `CompressedTensorsW8A8Fp8MoEMethod` yet."
)
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
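        # topk_weights and topk_ids hold the per-token routing weights and
        # expert indices, each of shape [num_tokens, top_k].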
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
        #
        # Note: the order of the branches below matters. A modular
        # self.fused_experts, if set, overrides the CUTLASS FP8 and Triton
        # fused_experts paths, but never the Marlin or ROCm AITER paths.
        #
if self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE."
assert self.fused_experts is None
return fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
None,
None,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
quant_type_id=scalar_types.float8_e4m3fn.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
workspace=layer.workspace,
)
elif self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
rocm_aiter_fused_experts,
)
assert per_act_token == per_channel_quant
assert self.moe_quant_config is not None
assert self.fused_experts is None
return rocm_aiter_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
quant_config=self.moe_quant_config,
)
elif self.fused_experts is not None:
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
)
# cutlass path
elif self.use_cutlass:
assert self.moe_quant_config is not None
# small-batch fallback on SM100
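            # Heuristic: for very small batches the Triton fused_experts path
            # is used instead of the CUTLASS grouped GEMM, likely because the
            # grouped-GEMM setup overhead dominates at this size.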
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
from vllm.model_executor.layers.fused_moe import fused_experts
assert per_act_token == per_channel_quant
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
quant_config=self.moe_quant_config,
)
else:
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8,
)
assert per_act_token == per_channel_quant
assert self.moe_quant_config is not None
return cutlass_moe_fp8(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
quant_config=self.moe_quant_config,
activation=activation,
global_num_experts=global_num_experts,
expert_map=None if self.disable_expert_map else expert_map,
ab_strides1=self.ab_strides1_c_strides2,
ab_strides2=self.ab_strides2,
c_strides1=self.c_strides1,
c_strides2=self.ab_strides1_c_strides2,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
assert per_act_token == per_channel_quant
assert self.moe_quant_config is not None
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
quant_config=self.moe_quant_config,
)