Skip to content

vllm.model_executor.layers.quantization.neuron_quant

SUPPORTED_QUANT_DTYPE_LIST module-attribute

# Quantization dtypes accepted through the NEURON_QUANT_DTYPE environment
# variable: "s8" (int8) and "f8e4m3fn" (an FP8 format).
SUPPORTED_QUANT_DTYPE_LIST = ["s8", "f8e4m3fn"]

AlwaysSupportedDtypes

Bases: list

Source code in vllm/model_executor/layers/quantization/neuron_quant.py
class AlwaysSupportedDtypes(list):
    """A ``list`` subclass whose membership test always succeeds.

    ``x in AlwaysSupportedDtypes()`` is ``True`` for every ``x``, so code that
    checks a dtype against this container never rejects it. Only
    ``__contains__`` is overridden; iteration, ``len`` and every other
    ``list`` behaviour are inherited unchanged.
    """

    def __contains__(self, item):
        # ``item`` is intentionally ignored: every value counts as a member.
        del item
        return True

__contains__

__contains__(item)
Source code in vllm/model_executor/layers/quantization/neuron_quant.py
def __contains__(self, item):
    return True

NeuronQuantConfig

Bases: QuantizationConfig

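Quantization config class for the Neuron backend; supports the int8 (`s8`) and FP8 (`f8e4m3fn`) quantization dtypes listed in `SUPPORTED_QUANT_DTYPE_LIST`.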

Source code in vllm/model_executor/layers/quantization/neuron_quant.py
class NeuronQuantConfig(QuantizationConfig):
    """Quantization config for the Neuron backend (int8 or FP8 weights).

    The quantization dtype is read from the ``NEURON_QUANT_DTYPE`` environment
    variable (default ``"s8"``) and must appear in
    ``SUPPORTED_QUANT_DTYPE_LIST``. The work itself is delegated to
    ``transformers_neuronx``: ``get_quant_method`` returns a
    ``transformers_neuronx.config.QuantizationConfig`` built from this
    object's settings.
    """

    def __init__(
        self,
        dequant_dtype: str = "f16",
        quantize_method: str = "vector_dynamic",
    ) -> None:
        """Read and validate the quantization dtype, then store the options.

        Args:
            dequant_dtype: Dequantization dtype, forwarded unchanged to
                ``transformers_neuronx``.
            quantize_method: Quantization scheme name, forwarded unchanged to
                ``transformers_neuronx``.

        Raises:
            ValueError: If ``NEURON_QUANT_DTYPE`` is not one of
                ``SUPPORTED_QUANT_DTYPE_LIST``.
        """
        super().__init__()
        self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
        if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
            message = (f"Neuron quantization datatype {self.quant_dtype} "
                       "is not valid, the quantization datatype should match "
                       f"one of the below types {SUPPORTED_QUANT_DTYPE_LIST}")
            raise ValueError(message)
        self.dequant_dtype = dequant_dtype
        self.quantize_method = quantize_method

    def get_name(self) -> QuantizationMethods:
        """Return the identifier of this quantization method."""
        name = "neuron_quant"
        return name

    def get_supported_act_dtypes(self) -> list[str]:
        """Return a container that accepts every activation dtype.

        Neuron implements its own handling of quantization support, so no
        activation dtype is filtered out at this level.
        """
        supported = AlwaysSupportedDtypes()
        return supported

    @classmethod
    def get_min_capability(cls) -> int:
        """Unsupported on Neuron.

        Raises:
            NotImplementedError: Always; this query is not expected to be
                made on the Neuron backend.
        """
        reason = "This function should not be called with Neuron Backend"
        raise NotImplementedError(reason)

    @staticmethod
    def get_config_filenames() -> list[str]:
        """Return the config filenames to search for; Neuron lists none."""
        return list()

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig":
        """Build a config from a serialized quantization-config dict.

        Reads ``"quantize_method"`` and then ``"dequant_dtype"`` via
        ``get_from_keys`` (presumably raising when a key is absent — see the
        base class). ``NEURON_QUANT_DTYPE`` still comes from the environment.
        """
        return cls(
            quantize_method=cls.get_from_keys(config, ["quantize_method"]),
            dequant_dtype=cls.get_from_keys(config, ["dequant_dtype"]),
        )

    def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
        """Return a ``transformers_neuronx`` quantization config.

        ``layer`` and ``prefix`` are not inspected: every call returns a
        config built from this object's settings.

        Raises:
            NotImplementedError: If ``find_spec`` cannot locate
                ``transformers_neuronx``.
        """
        if find_spec("transformers_neuronx") is None:
            raise NotImplementedError(
                "Neuron Quantization is only supported through"
                " transformers_neuronx.")
        return self.get_quantization_config()

    def get_quantization_config(self):
        """Translate this config into a ``transformers_neuronx`` config."""
        # Imported inside the method so transformers_neuronx is only needed
        # when this is called; aliased to avoid shadowing vLLM's base class.
        from transformers_neuronx.config import (
            QuantizationConfig as NeuronxQuantizationConfig)
        return NeuronxQuantizationConfig(
            quant_dtype=self.quant_dtype,
            dequant_dtype=self.dequant_dtype,
            quantize_method=self.quantize_method,
        )

dequant_dtype instance-attribute

# NeuronQuantConfig instance attribute: assigned in __init__ from the
# dequant_dtype argument (default "f16").
dequant_dtype = dequant_dtype

quant_dtype instance-attribute

# NeuronQuantConfig instance attribute: read from the NEURON_QUANT_DTYPE
# environment variable in __init__ (default 's8') and validated against
# SUPPORTED_QUANT_DTYPE_LIST.
quant_dtype = getenv('NEURON_QUANT_DTYPE', 's8')

quantize_method instance-attribute

# NeuronQuantConfig instance attribute: assigned in __init__ from the
# quantize_method argument (default "vector_dynamic").
quantize_method = quantize_method

__init__

__init__(
    dequant_dtype: str = "f16",
    quantize_method: str = "vector_dynamic",
) -> None
Source code in vllm/model_executor/layers/quantization/neuron_quant.py
def __init__(
    self,
    dequant_dtype: str = "f16",
    quantize_method: str = "vector_dynamic",
) -> None:
    """Read and validate the quantization dtype, then store the options.

    Args:
        dequant_dtype: Dequantization dtype, forwarded unchanged to
            ``transformers_neuronx``.
        quantize_method: Quantization scheme name, forwarded unchanged to
            ``transformers_neuronx``.

    Raises:
        ValueError: If ``NEURON_QUANT_DTYPE`` is not one of
            ``SUPPORTED_QUANT_DTYPE_LIST``.
    """
    super().__init__()
    self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
    if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
        message = (f"Neuron quantization datatype {self.quant_dtype} "
                   "is not valid, the quantization datatype should match "
                   f"one of the below types {SUPPORTED_QUANT_DTYPE_LIST}")
        raise ValueError(message)
    self.dequant_dtype = dequant_dtype
    self.quantize_method = quantize_method

from_config classmethod

from_config(config: dict[str, Any]) -> NeuronQuantConfig
Source code in vllm/model_executor/layers/quantization/neuron_quant.py
@classmethod
def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig":
    """Build a config from a serialized quantization-config dict.

    Reads ``"quantize_method"`` and then ``"dequant_dtype"`` via
    ``get_from_keys`` (presumably raising when a key is absent — see the base
    class). ``NEURON_QUANT_DTYPE`` still comes from the environment.
    """
    return cls(
        quantize_method=cls.get_from_keys(config, ["quantize_method"]),
        dequant_dtype=cls.get_from_keys(config, ["dequant_dtype"]),
    )

get_config_filenames staticmethod

get_config_filenames() -> list[str]
Source code in vllm/model_executor/layers/quantization/neuron_quant.py
@staticmethod
def get_config_filenames() -> list[str]:
    """Return the config filenames to search for; Neuron lists none."""
    return list()

get_min_capability classmethod

get_min_capability() -> int
Source code in vllm/model_executor/layers/quantization/neuron_quant.py
@classmethod
def get_min_capability(cls) -> int:
    """Unsupported on Neuron.

    Raises:
        NotImplementedError: Always; this query is not expected to be made on
            the Neuron backend.
    """
    reason = "This function should not be called with Neuron Backend"
    raise NotImplementedError(reason)

get_name

get_name() -> QuantizationMethods
Source code in vllm/model_executor/layers/quantization/neuron_quant.py
def get_name(self) -> QuantizationMethods:
    """Return the identifier of this quantization method."""
    name = "neuron_quant"
    return name

get_quant_method

get_quant_method(
    layer: Module, prefix: str
) -> Optional[Any]
Source code in vllm/model_executor/layers/quantization/neuron_quant.py
def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
    """Return a ``transformers_neuronx`` quantization config.

    ``layer`` and ``prefix`` are not inspected: every call returns whatever
    ``self.get_quantization_config()`` builds.

    Raises:
        NotImplementedError: If ``find_spec`` cannot locate
            ``transformers_neuronx``.
    """
    if find_spec("transformers_neuronx") is None:
        raise NotImplementedError(
            "Neuron Quantization is only supported through"
            " transformers_neuronx.")
    return self.get_quantization_config()

get_quantization_config

get_quantization_config()
Source code in vllm/model_executor/layers/quantization/neuron_quant.py
def get_quantization_config(self):
    """Translate this config into a ``transformers_neuronx`` config object."""
    # Imported inside the function so transformers_neuronx is only needed when
    # this is called; aliased to avoid shadowing vLLM's QuantizationConfig.
    from transformers_neuronx.config import (
        QuantizationConfig as NeuronxQuantizationConfig)
    return NeuronxQuantizationConfig(
        quant_dtype=self.quant_dtype,
        dequant_dtype=self.dequant_dtype,
        quantize_method=self.quantize_method,
    )

get_supported_act_dtypes

get_supported_act_dtypes() -> list[str]
Source code in vllm/model_executor/layers/quantization/neuron_quant.py
def get_supported_act_dtypes(self) -> list[str]:
    """Return a container that accepts every activation dtype.

    Neuron implements its own handling of quantization support, so no
    activation dtype is filtered out at this level.
    """
    supported = AlwaysSupportedDtypes()
    return supported