vllm.model_executor.layers.quantization.fbgemm_fp8

logger module-attribute

logger = init_logger(__name__)

FBGEMMFp8Config

Bases: QuantizationConfig

Config class for FBGEMM Fp8.

Source code in vllm/model_executor/layers/quantization/fbgemm_fp8.py
class FBGEMMFp8Config(QuantizationConfig):
    """Config class for FBGEMM Fp8."""

    def __init__(self, ignore_list: list[str], input_scale_ub: float):
        super().__init__()
        self.ignore_list = ignore_list if ignore_list else []
        self.input_scale_ub = input_scale_ub

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = not current_platform.has_device_capability(89)
        self.fp8_linear = Fp8LinearOp()

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fbgemm_fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.float16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "FBGEMMFp8Config":
        ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
        input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
        return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix=prefix,
                                ignored_layers=self.ignore_list,
                                fused_mapping=self.packed_modules_mapping):
                return UnquantizedLinearMethod()
            return FBGEMMFp8LinearMethod(self)
        return None
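
As a hedged usage sketch, this config is normally built by from_config from the checkpoint's quantization_config dict; the key names below match the ones read in from_config, but the concrete values are illustrative:

from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config

# Illustrative quantization_config as it might appear in a checkpoint's
# config.json (values made up for the example).
hf_quant_config = {
    "quant_method": "fbgemm_fp8",
    "modules_to_not_convert": ["lm_head"],  # layers left unquantized
    "activation_scale_ub": 1200.0,          # upper bound for activation scales
}

config = FBGEMMFp8Config.from_config(hf_quant_config)
assert config.ignore_list == ["lm_head"]
assert config.input_scale_ub == 1200.0
# config.use_marlin is True on GPUs below compute capability 8.9, which lack
# native FP8 tensor cores and fall back to the weight-only Marlin kernel.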

fp8_linear instance-attribute

fp8_linear = Fp8LinearOp()

ignore_list instance-attribute

ignore_list = ignore_list if ignore_list else []

input_scale_ub instance-attribute

input_scale_ub = input_scale_ub

use_marlin instance-attribute

use_marlin = not has_device_capability(89)

__init__

__init__(ignore_list: list[str], input_scale_ub: float)
Source code in vllm/model_executor/layers/quantization/fbgemm_fp8.py
def __init__(self, ignore_list: list[str], input_scale_ub: float):
    super().__init__()
    self.ignore_list = ignore_list if ignore_list else []
    self.input_scale_ub = input_scale_ub

    # For GPUs that lack FP8 hardware support, we can leverage the Marlin
    # kernel for fast weight-only FP8 quantization
    self.use_marlin = not current_platform.has_device_capability(89)
    self.fp8_linear = Fp8LinearOp()

from_config classmethod

from_config(config: dict[str, Any]) -> FBGEMMFp8Config
Source code in vllm/model_executor/layers/quantization/fbgemm_fp8.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "FBGEMMFp8Config":
    ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
    input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
    return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)

get_config_filenames classmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/fbgemm_fp8.py
@classmethod
def get_config_filenames(cls) -> list[str]:
    return []

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/fbgemm_fp8.py
@classmethod
def get_min_capability(cls) -> int:
    return 80

get_name classmethod

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/fbgemm_fp8.py
@classmethod
def get_name(cls) -> QuantizationMethods:
    return "fbgemm_fp8"

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[QuantizeMethodBase]
Source code in vllm/model_executor/layers/quantization/fbgemm_fp8.py
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    if isinstance(layer, LinearBase):
        if is_layer_skipped(prefix=prefix,
                            ignored_layers=self.ignore_list,
                            fused_mapping=self.packed_modules_mapping):
            return UnquantizedLinearMethod()
        return FBGEMMFp8LinearMethod(self)
    return None
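
To illustrate the skip decision above (a sketch, assuming is_layer_skipped is importable from vllm.model_executor.layers.quantization.utils.quant_utils, as this module does, and that prefixes are matched against the ignore list by name):

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)

ignore_list = ["lm_head"]

# A prefix listed in modules_to_not_convert is skipped, so get_quant_method
# returns UnquantizedLinearMethod for it ...
is_layer_skipped(prefix="lm_head",
                 ignored_layers=ignore_list,
                 fused_mapping={})   # -> True

# ... while ordinary projection layers get FBGEMMFp8LinearMethod.
is_layer_skipped(prefix="model.layers.0.self_attn.o_proj",
                 ignored_layers=ignore_list,
                 fused_mapping={})   # -> False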

get_supported_act_dtypes classmethod

get_supported_act_dtypes() -> list[dtype]
Source code in vllm/model_executor/layers/quantization/fbgemm_fp8.py
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    return [torch.bfloat16, torch.float16]

FBGEMMFp8LinearMethod

Bases: LinearMethodBase

Source code in vllm/model_executor/layers/quantization/fbgemm_fp8.py
class FBGEMMFp8LinearMethod(LinearMethodBase):

    def __init__(self, quant_config: FBGEMMFp8Config):
        self.quant_config = quant_config
        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
        self.out_dtype = torch.get_default_dtype()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        maybe_create_device_identity()
        weight_loader = extra_weight_attrs.get("weight_loader")
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=torch.float8_e4m3fn),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        weight_scale = ChannelQuantScaleParameter(data=torch.empty(
            (sum(output_partition_sizes), 1), dtype=torch.float32),
                                                  output_dim=0,
                                                  weight_loader=weight_loader)
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE UPPER BOUND
        input_scale_ub = torch.nn.Parameter(torch.tensor(
            (self.quant_config.input_scale_ub), dtype=torch.float32),
                                            requires_grad=False)
        layer.input_scale_ub = input_scale_ub

    def process_weights_after_loading(self, layer: Module) -> None:
        # required by torch.compile
        layer.weight_scale = Parameter(layer.weight_scale.data,
                                       requires_grad=False)
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        weight = layer.weight

        if current_platform.is_fp8_fnuz():
            weight, weight_scale, input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    weight=weight,
                    weight_scale=layer.weight_scale,
                    input_scale=None)
            if input_scale is not None:
                layer.input_scale = Parameter(input_scale, requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)

        layer.weight = Parameter(weight.t(), requires_grad=False)
        if self.quant_config.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale_ub

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        if self.quant_config.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=None,
                                     input_scale_ub=layer.input_scale_ub,
                                     bias=bias)

fp8_linear instance-attribute

fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

out_dtype instance-attribute

out_dtype = get_default_dtype()

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(quant_config: FBGEMMFp8Config)
Source code in vllm/model_executor/layers/quantization/fbgemm_fp8.py
def __init__(self, quant_config: FBGEMMFp8Config):
    self.quant_config = quant_config
    self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
    self.out_dtype = torch.get_default_dtype()

apply

apply(
    layer: Module, x: Tensor, bias: Optional[Tensor] = None
) -> Tensor
Source code in vllm/model_executor/layers/quantization/fbgemm_fp8.py
def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:

    if self.quant_config.use_marlin:
        return apply_fp8_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias)

    return self.fp8_linear.apply(input=x,
                                 weight=layer.weight,
                                 weight_scale=layer.weight_scale,
                                 out_dtype=self.out_dtype,
                                 input_scale=None,
                                 input_scale_ub=layer.input_scale_ub,
                                 bias=bias)
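
The non-Marlin path delegates to Fp8LinearOp: activations are quantized with dynamic per-token scales clamped from above by input_scale_ub, and weights carry per-output-channel scales. A simplified PyTorch reference of that math (a sketch only; the real kernel fuses quantization and the FP8 GEMM, and the function name here is illustrative):

import torch
from typing import Optional

def fp8_linear_reference(x: torch.Tensor,
                         weight_fp8: torch.Tensor,     # [K, N] float8_e4m3fn, transposed as stored
                         weight_scale: torch.Tensor,   # [N, 1] float32, per output channel
                         input_scale_ub: torch.Tensor, # scalar float32
                         bias: Optional[torch.Tensor] = None,
                         out_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    fp8_max = torch.finfo(torch.float8_e4m3fn).max

    # Dynamic per-token scale: per-row absmax, clamped by input_scale_ub so a
    # single outlier feature cannot inflate the shared scale of a token.
    x_absmax = x.abs().amax(dim=-1, keepdim=True).float()
    x_scale = torch.minimum(x_absmax, input_scale_ub) / fp8_max   # [M, 1]

    # Quantize activations to FP8, saturating at the representable range.
    x_fp8 = (x.float() / x_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)

    # Dequantize-and-matmul in float32 for clarity; the fused kernel instead
    # runs an FP8 GEMM and applies both scales on the accumulator.
    out = (x_fp8.float() * x_scale) @ (weight_fp8.float() * weight_scale.t())
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)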

create_weights

create_weights(
    layer: Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: dtype,
    **extra_weight_attrs,
)
Source code in vllm/model_executor/layers/quantization/fbgemm_fp8.py
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    maybe_create_device_identity()
    weight_loader = extra_weight_attrs.get("weight_loader")
    del input_size, output_size
    output_size_per_partition = sum(output_partition_sizes)

    layer.logical_widths = output_partition_sizes

    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition
    layer.orig_dtype = params_dtype

    # WEIGHT
    weight = ModelWeightParameter(data=torch.empty(
        output_size_per_partition,
        input_size_per_partition,
        dtype=torch.float8_e4m3fn),
                                  input_dim=1,
                                  output_dim=0,
                                  weight_loader=weight_loader)
    layer.register_parameter("weight", weight)

    # WEIGHT SCALE
    weight_scale = ChannelQuantScaleParameter(data=torch.empty(
        (sum(output_partition_sizes), 1), dtype=torch.float32),
                                              output_dim=0,
                                              weight_loader=weight_loader)
    weight_scale[:] = torch.finfo(torch.float32).min
    layer.register_parameter("weight_scale", weight_scale)

    # INPUT SCALE UPPER BOUND
    input_scale_ub = torch.nn.Parameter(torch.tensor(
        (self.quant_config.input_scale_ub), dtype=torch.float32),
                                        requires_grad=False)
    layer.input_scale_ub = input_scale_ub
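
For a sense of what gets registered (sizes below are made up; rows follow sum(output_partition_sizes), columns follow input_size_per_partition):

    output_partition_sizes = [1024, 1024, 512], input_size_per_partition = 4096

    layer.weight          float8_e4m3fn, shape (2560, 4096)   # loaded checkpoint weight
    layer.weight_scale    float32,       shape (2560, 1)      # one scale per output channel
    layer.input_scale_ub  float32 scalar                      # activation_scale_ub from the config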

process_weights_after_loading

process_weights_after_loading(layer: Module) -> None
Source code in vllm/model_executor/layers/quantization/fbgemm_fp8.py
def process_weights_after_loading(self, layer: Module) -> None:
    # required by torch.compile
    layer.weight_scale = Parameter(layer.weight_scale.data,
                                   requires_grad=False)
    layer.weight = Parameter(layer.weight.data, requires_grad=False)

    weight = layer.weight

    if current_platform.is_fp8_fnuz():
        weight, weight_scale, input_scale = \
            normalize_e4m3fn_to_e4m3fnuz(
                weight=weight,
                weight_scale=layer.weight_scale,
                input_scale=None)
        if input_scale is not None:
            layer.input_scale = Parameter(input_scale, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)

    layer.weight = Parameter(weight.t(), requires_grad=False)
    if self.quant_config.use_marlin:
        prepare_fp8_layer_for_marlin(layer)
        # Activations not quantized for marlin.
        del layer.input_scale_ub
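
Continuing the illustrative sizes from create_weights above, the net effect on the layer (a sketch, not an exhaustive description):

    before:  layer.weight        (2560, 4096) float8_e4m3fn   # [out, in], as loaded
    after:   layer.weight        (4096, 2560) float8_e4m3fn   # transposed to [in, out] for the FP8 GEMM
             layer.weight_scale  (2560, 1)    float32         # values adjusted if the platform uses e4m3fnuz

When use_marlin is True, prepare_fp8_layer_for_marlin additionally repacks the layer for the weight-only Marlin kernel, and input_scale_ub is deleted because that path leaves activations unquantized.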