# Import paths below assume vLLM's standard module layout for
# compressed-tensors schemes and the Marlin FP4 utilities.
from typing import Callable

import torch
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           ModelWeightParameter)


class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme):
    """
    Compressed-tensors scheme for MXFP4 weight-only quantization.
    Supports models quantized with the compressed-tensors
    mxfp4-pack-quantized format.

    MXFP4 format:
    - 4-bit float weights (E2M1), two values packed per uint8
    - Per-group E8M0 scales with group_size=32
    - No global scale (unlike NVFP4)
    """

    def __init__(self):
        self.group_size = 32

    @classmethod
    def get_min_capability(cls) -> int:
        # Requires Ampere (compute capability 8.0) or newer.
        return 80

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.params_dtype = params_dtype

        # Packed FP4 weights: two 4-bit E2M1 values per uint8, so the input
        # dimension is halved.
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_packed", weight)

        # Per-group E8M0 scales: one uint8 exponent per group of 32 weights
        # along the input dimension.
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.group_size,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Re-register the packed weights under the "weight" attribute name
        # that the Marlin path expects, then repack for the Marlin kernels.
        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
        del layer.weight_packed

        prepare_fp4_layer_for_marlin(layer)
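        # After this call the layer is expected to hold Marlin-format weights
        # and scales plus the layer.workspace buffer used by apply_weights.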

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # MXFP4 carries no global scale, so weight_global_scale stays None
        # (unlike the NVFP4 path).
        return apply_fp4_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_global_scale=None,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )
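

# Typical lifecycle of this scheme (a sketch; the quantization config and
# linear-layer code that drive it live elsewhere, and "loader" below stands in
# for a hypothetical weight_loader callable):
#   scheme = CompressedTensorsW4A16Mxfp4()
#   scheme.create_weights(layer, output_partition_sizes=[4096],
#                         input_size_per_partition=4096,
#                         params_dtype=torch.bfloat16, weight_loader=loader)
#   ... checkpoint shards load into weight_packed / weight_scale ...
#   scheme.process_weights_after_loading(layer)
#   y = scheme.apply_weights(layer, x, bias=None)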